LangChain跨会话记忆恢复技术源码解析
I. 跨会话记忆恢复技术概述
1.1 技术背景与重要性
LangChain作为一个强大的语言模型应用框架,其跨会话记忆恢复技术是实现连贯、上下文感知对话的核心能力之一。在多轮对话场景中,系统需要能够记住之前的交互内容,以便理解当前问题的上下文,提供更准确、连贯的回答。这种能力对于智能客服、聊天机器人、虚拟助手等应用至关重要,直接影响用户体验和系统实用性。
传统的对话系统通常只能处理单次交互,无法有效地跟踪和利用历史对话信息。而LangChain的跨会话记忆恢复技术通过巧妙的设计和实现,使得系统能够在不同会话之间保持上下文连贯性,为用户提供更加自然、流畅的交互体验。
1.2 核心概念与目标
跨会话记忆恢复技术的核心概念包括记忆存储、记忆检索和记忆更新。记忆存储负责将对话历史以某种格式保存下来,记忆检索则是在需要时从存储中提取相关信息,而记忆更新则是随着对话的进行不断调整和完善记忆内容。
该技术的主要目标包括:
- 实现对话历史的持久化存储,确保会话中断后能够恢复高效地检索与当前问题相关的历史信息智能地管理记忆容量,避免存储过多无关信息支持不同类型的记忆存储方式,如向量数据库、键值存储等
1.3 技术挑战与解决方案
跨会话记忆恢复技术面临着诸多挑战,包括:
- 记忆容量管理:如何在有限的存储资源下保存最有价值的对话历史高效检索:如何快速从大量历史信息中找到相关内容上下文理解:如何准确理解当前问题与历史对话的关联隐私保护:如何确保用户对话内容的安全性和隐私性
LangChain通过一系列创新的解决方案来应对这些挑战,包括:
- 使用向量嵌入技术将文本转换为高维向量,便于高效检索和相似度计算实现分层记忆存储策略,根据对话的重要性和时效性进行不同级别的存储采用注意力机制,在检索时聚焦于最相关的历史信息提供灵活的记忆清除和过期机制,管理记忆容量集成加密和访问控制技术,保护用户隐私
II. 记忆接口设计与核心组件
2.1 记忆接口定义
LangChain的记忆功能通过定义明确的接口来实现,这些接口为不同类型的记忆实现提供了统一的规范。核心的记忆接口包括BaseMemory
类,它定义了记忆系统的基本操作:
from abc import ABC, abstractmethodfrom typing import Dict, List, Any, Optionalclass BaseMemory(ABC): """记忆抽象基类,定义记忆系统的核心接口""" @abstractmethod def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """加载与当前输入相关的记忆变量""" pass @abstractmethod def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: """保存当前上下文到记忆中""" pass @abstractmethod def clear(self) -> None: """清除记忆内容""" pass
这个接口定义了三个核心方法:
load_memory_variables
:根据当前输入加载相关的历史记忆save_context
:将当前对话上下文保存到记忆中clear
:清除所有记忆内容2.2 核心记忆组件
LangChain提供了多种具体的记忆实现,以满足不同场景的需求。这些实现包括:
2.2.1 简单内存记忆(SimpleMemory)
class SimpleMemory(BaseMemory): """简单的内存记忆实现,使用列表存储对话历史""" def __init__(self): """初始化简单内存记忆""" self.chat_memory = [] def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """加载所有对话历史""" return {"history": self.chat_memory} def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: """保存当前对话上下文""" # 从输入中提取用户消息 user_message = inputs.get("input", "") # 从输出中提取AI回复 ai_message = outputs.get("output", "") # 添加到对话历史 self.chat_memory.append({"input": user_message, "output": ai_message}) def clear(self) -> None: """清除对话历史""" self.chat_memory = []
SimpleMemory是最基本的记忆实现,它使用Python列表在内存中存储对话历史。这种实现简单直接,但缺点是会话结束后记忆会丢失,且无法在不同实例之间共享。
2.2.2 聊天消息历史记忆(ChatMessageHistory)
from langchain.schema import BaseMessage, HumanMessage, AIMessageclass ChatMessageHistory: """聊天消息历史记录器,管理消息列表""" def __init__(self): """初始化聊天消息历史""" self.messages = [] def add_user_message(self, message: str) -> None: """添加用户消息""" self.messages.append(HumanMessage(content=message)) def add_ai_message(self, message: str) -> None: """添加AI消息""" self.messages.append(AIMessage(content=message)) def clear(self) -> None: """清除所有消息""" self.messages = [] def to_dict(self) -> List[Dict[str, Any]]: """将消息转换为字典列表""" return [{"type": type(msg).__name__, "content": msg.content} for msg in self.messages] @classmethod def from_dict(cls, messages_dict: List[Dict[str, Any]]) -> "ChatMessageHistory": """从字典列表创建聊天消息历史""" history = cls() for msg_dict in messages_dict: msg_type = msg_dict["type"] content = msg_dict["content"] if msg_type == "HumanMessage": history.add_user_message(content) elif msg_type == "AIMessage": history.add_ai_message(content) return history
ChatMessageHistory专门用于管理聊天消息,它使用LangChain的消息模型(HumanMessage和AIMessage)来存储用户和AI的消息。这种实现提供了更结构化的消息管理方式,并支持消息的序列化和反序列化,便于持久化存储。
2.2.3 向量记忆(VectorStoreRetrieverMemory)
from langchain.vectorstores import VectorStorefrom langchain.retrievers import VectorStoreRetrieverfrom langchain.embeddings import Embeddingsclass VectorStoreRetrieverMemory(BaseMemory): """基于向量存储的记忆实现,使用相似度检索相关记忆""" def __init__(self, vectorstore: VectorStore, embeddings: Embeddings, k: int = 5, memory_key: str = "history"): """初始化向量记忆""" self.vectorstore = vectorstore self.embeddings = embeddings self.k = k self.memory_key = memory_key # 创建检索器 self.retriever = VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": k}) def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """加载与当前输入相关的记忆""" query = inputs.get("input", "") # 如果查询为空,返回空历史 if not query: return {self.memory_key: []} # 将查询文本转换为嵌入向量 query_embedding = self.embeddings.embed_query(query) # 检索相关记忆 docs = self.retriever.get_relevant_documents(query) # 提取记忆文本 history = [doc.page_content for doc in docs] return {self.memory_key: history} def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: """保存当前上下文到向量存储""" # 从输入中提取用户消息 user_message = inputs.get("input", "") # 从输出中提取AI回复 ai_message = outputs.get("output", "") # 组合成完整的对话消息 full_message = f"User: {user_message}\nAI: {ai_message}" # 添加到向量存储 self.vectorstore.add_texts([full_message]) def clear(self) -> None: """清除所有记忆""" # 删除向量存储中的所有文档 self.vectorstore.delete_documents()
VectorStoreRetrieverMemory是一种更高级的记忆实现,它使用向量数据库来存储和检索记忆。这种实现的优势在于能够根据语义相似度检索相关历史信息,而不仅仅是基于文本匹配。它通过将文本转换为向量表示,利用向量数据库的相似度检索能力,找到与当前问题最相关的历史对话。
2.3 记忆管理器
为了更灵活地管理不同类型的记忆,LangChain提供了记忆管理器(MemoryManager):
class MemoryManager: """记忆管理器,管理多个记忆组件""" def __init__(self): """初始化记忆管理器""" self.memories = {} def add_memory(self, name: str, memory: BaseMemory) -> None: """添加记忆组件""" self.memories[name] = memory def get_memory(self, name: str) -> Optional[BaseMemory]: """获取记忆组件""" return self.memories.get(name) def load_all_memories(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """加载所有记忆变量""" all_memory_vars = {} for name, memory in self.memories.items(): memory_vars = memory.load_memory_variables(inputs) all_memory_vars.update(memory_vars) return all_memory_vars def save_all_contexts(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: """保存所有记忆上下文""" for name, memory in self.memories.items(): memory.save_context(inputs, outputs) def clear_all_memories(self) -> None: """清除所有记忆""" for name, memory in self.memories.items(): memory.clear()
记忆管理器允许应用同时使用多种类型的记忆组件,并统一管理它们的加载、保存和清除操作。这对于复杂应用场景非常有用,可以根据不同的需求选择合适的记忆实现。
III. 跨会话记忆存储机制
3.1 内存存储实现
内存存储是最简单的存储方式,它将记忆保存在应用的内存中。这种方式的优点是速度快,实现简单,但缺点是会话结束后记忆会丢失,且无法在不同实例之间共享。
LangChain提供了几种内存存储的实现,如前面介绍的SimpleMemory和ChatMessageHistory。这些实现直接使用Python数据结构(如列表、字典)来存储记忆内容。
3.2 持久化存储实现
为了实现跨会话的记忆恢复,需要将记忆持久化存储到磁盘或数据库中。LangChain提供了多种持久化存储的实现。
3.2.1 文件存储
import jsonimport osfrom typing import Dict, Any, Optionalclass FileChatMessageHistory(ChatMessageHistory): """将聊天消息历史保存到文件的实现""" def __init__(self, file_path: str): """初始化文件存储的聊天消息历史""" super().__init__() self.file_path = file_path # 如果文件存在,加载历史消息 if os.path.exists(file_path): self._load_from_file() def _load_from_file(self) -> None: """从文件加载消息历史""" try: with open(self.file_path, "r") as f: messages_dict = json.load(f) self.messages = [] for msg_dict in messages_dict: msg_type = msg_dict["type"] content = msg_dict["content"] if msg_type == "HumanMessage": self.add_user_message(content) elif msg_type == "AIMessage": self.add_ai_message(content) except Exception as e: print(f"Error loading messages from file: {e}") self.messages = [] def add_user_message(self, message: str) -> None: """添加用户消息并保存到文件""" super().add_user_message(message) self._save_to_file() def add_ai_message(self, message: str) -> None: """添加AI消息并保存到文件""" super().add_ai_message(message) self._save_to_file() def clear(self) -> None: """清除消息并删除文件""" super().clear() if os.path.exists(self.file_path): os.remove(self.file_path) def _save_to_file(self) -> None: """将消息保存到文件""" messages_dict = self.to_dict() try: # 创建目录(如果不存在) os.makedirs(os.path.dirname(self.file_path), exist_ok=True) with open(self.file_path, "w") as f: json.dump(messages_dict, f) except Exception as e: print(f"Error saving messages to file: {e}")
FileChatMessageHistory实现了将聊天消息保存到JSON文件的功能。每次添加新消息时,它会将消息追加到文件中;每次初始化时,它会从文件中加载历史消息。这种实现适合小型应用或需要简单持久化的场景。
3.2.2 数据库存储
对于更复杂的应用场景,LangChain提供了基于数据库的存储实现,如SQLite和Redis。
import sqlite3from typing import Dict, Any, Optionalclass SQLiteChatMessageHistory(ChatMessageHistory): """将聊天消息历史保存到SQLite数据库的实现""" def __init__(self, db_path: str, session_id: str): """初始化SQLite存储的聊天消息历史""" super().__init__() self.db_path = db_path self.session_id = session_id # 确保表存在 self._create_table() # 加载历史消息 self._load_from_db() def _create_table(self) -> None: """创建消息表(如果不存在)""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute(""" CREATE TABLE IF NOT EXISTS chat_messages ( session_id TEXT, message_type TEXT, content TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) """) conn.commit() conn.close() def _load_from_db(self) -> None: """从数据库加载消息历史""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute( "SELECT message_type, content FROM chat_messages WHERE session_id = ? ORDER BY timestamp ASC", (self.session_id,) ) rows = cursor.fetchall() for msg_type, content in rows: if msg_type == "HumanMessage": self.add_user_message(content) elif msg_type == "AIMessage": self.add_ai_message(content) conn.close() def add_user_message(self, message: str) -> None: """添加用户消息并保存到数据库""" super().add_user_message(message) self._save_to_db("HumanMessage", message) def add_ai_message(self, message: str) -> None: """添加AI消息并保存到数据库""" super().add_ai_message(message) self._save_to_db("AIMessage", message) def clear(self) -> None: """清除消息并从数据库删除""" super().clear() conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute("DELETE FROM chat_messages WHERE session_id = ?", (self.session_id,)) conn.commit() conn.close() def _save_to_db(self, message_type: str, content: str) -> None: """将消息保存到数据库""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute( "INSERT INTO chat_messages (session_id, message_type, content) VALUES (?, ?, ?)", (self.session_id, message_type, content) ) conn.commit() conn.close()
SQLiteChatMessageHistory实现了将聊天消息保存到SQLite数据库的功能。它为每个会话创建一个唯一的标识符,并将所有相关消息存储在同一个会话下。这种实现支持跨会话的记忆恢复,并且可以方便地查询和管理历史消息。
3.2.3 Redis存储
import redisimport jsonfrom typing import Dict, Any, Optionalclass RedisChatMessageHistory(ChatMessageHistory): """将聊天消息历史保存到Redis的实现""" def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0, session_id: str = "default", password: str = None): """初始化Redis存储的聊天消息历史""" super().__init__() self.session_id = session_id # 连接Redis self.redis_client = redis.Redis( host=host, port=port, db=db, password=password ) # 加载历史消息 self._load_from_redis() def _load_from_redis(self) -> None: """从Redis加载消息历史""" key = f"chat_history:{self.session_id}" messages_json = self.redis_client.get(key) if messages_json: messages_dict = json.loads(messages_json) for msg_dict in messages_dict: msg_type = msg_dict["type"] content = msg_dict["content"] if msg_type == "HumanMessage": self.add_user_message(content) elif msg_type == "AIMessage": self.add_ai_message(content) def add_user_message(self, message: str) -> None: """添加用户消息并保存到Redis""" super().add_user_message(message) self._save_to_redis() def add_ai_message(self, message: str) -> None: """添加AI消息并保存到Redis""" super().add_ai_message(message) self._save_to_redis() def clear(self) -> None: """清除消息并从Redis删除""" super().clear() key = f"chat_history:{self.session_id}" self.redis_client.delete(key) def _save_to_redis(self) -> None: """将消息保存到Redis""" key = f"chat_history:{self.session_id}" messages_dict = self.to_dict() try: self.redis_client.set(key, json.dumps(messages_dict)) except Exception as e: print(f"Error saving messages to Redis: {e}")
RedisChatMessageHistory实现了将聊天消息保存到Redis的功能。Redis是一种高性能的键值存储数据库,适合存储会话数据。这种实现特别适合分布式应用场景,可以在不同的应用实例之间共享会话记忆。
3.3 向量存储实现
向量存储是跨会话记忆恢复的核心技术之一,它将文本转换为向量表示,并利用向量相似度进行检索。LangChain提供了多种向量存储的实现,支持不同的向量数据库。
3.3.1 Chroma向量存储
from langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddingsclass ChromaMemoryStore: """基于Chroma向量数据库的记忆存储""" def __init__(self, persist_directory: str = None, embedding_function: OpenAIEmbeddings = None): """初始化Chroma向量存储""" self.embedding_function = embedding_function or OpenAIEmbeddings() self.persist_directory = persist_directory # 如果指定了持久化目录,尝试从目录加载 if persist_directory: self.vectorstore = Chroma( embedding_function=self.embedding_function, persist_directory=persist_directory ) else: # 创建新的向量存储 self.vectorstore = Chroma(embedding_function=self.embedding_function) def add_memory(self, text: str, metadata: Dict[str, Any] = None) -> None: """添加记忆到向量存储""" # 添加文本到向量存储 self.vectorstore.add_texts([text], metadatas=[metadata] if metadata else None) # 如果指定了持久化目录,持久化向量存储 if self.persist_directory: self.vectorstore.persist() def retrieve_relevant_memories(self, query: str, k: int = 5) -> List[Dict[str, Any]]: """检索与查询相关的记忆""" # 检索相关文档 docs = self.vectorstore.similarity_search(query, k=k) # 提取记忆内容和元数据 memories = [] for doc in docs: memories.append({ "content": doc.page_content, "metadata": doc.metadata }) return memories def delete_memory(self, memory_id: str) -> None: """删除指定ID的记忆""" self.vectorstore.delete([memory_id]) # 如果指定了持久化目录,持久化向量存储 if self.persist_directory: self.vectorstore.persist() def clear_all_memories(self) -> None: """清除所有记忆""" # 删除所有文档 all_ids = self.vectorstore.get()["ids"] if all_ids: self.vectorstore.delete(all_ids) # 如果指定了持久化目录,持久化向量存储 if self.persist_directory: self.vectorstore.persist()
ChromaMemoryStore实现了基于Chroma向量数据库的记忆存储功能。Chroma是一个专门为AI应用设计的向量数据库,提供了简单易用的API和高效的向量检索能力。这种实现适合需要基于语义相似度检索历史记忆的场景。
3.3.2 Pinecone向量存储
import pineconefrom langchain.vectorstores import Pineconefrom langchain.embeddings import OpenAIEmbeddingsclass PineconeMemoryStore: """基于Pinecone向量数据库的记忆存储""" def __init__(self, api_key: str, environment: str, index_name: str, embedding_function: OpenAIEmbeddings = None): """初始化Pinecone向量存储""" # 初始化Pinecone pinecone.init(api_key=api_key, environment=environment) self.embedding_function = embedding_function or OpenAIEmbeddings() self.index_name = index_name # 检查索引是否存在 if index_name not in pinecone.list_indexes(): raise ValueError(f"Pinecone index {index_name} does not exist") # 连接到索引 self.vectorstore = Pinecone( index=pinecone.Index(index_name), embedding_function=self.embedding_function, text_key="text" ) def add_memory(self, text: str, metadata: Dict[str, Any] = None) -> None: """添加记忆到向量存储""" # 添加文本到向量存储 self.vectorstore.add_texts([text], metadatas=[metadata] if metadata else None) def retrieve_relevant_memories(self, query: str, k: int = 5) -> List[Dict[str, Any]]: """检索与查询相关的记忆""" # 检索相关文档 docs = self.vectorstore.similarity_search(query, k=k) # 提取记忆内容和元数据 memories = [] for doc in docs: memories.append({ "content": doc.page_content, "metadata": doc.metadata }) return memories def delete_memory(self, memory_id: str) -> None: """删除指定ID的记忆""" self.vectorstore.delete([memory_id]) def clear_all_memories(self) -> None: """清除所有记忆""" # 删除所有文档 all_ids = self.vectorstore.get()["ids"] if all_ids: self.vectorstore.delete(all_ids)
PineconeMemoryStore实现了基于Pinecone向量数据库的记忆存储功能。Pinecone是一个云托管的向量数据库,提供了高性能的向量检索能力和自动扩展功能。这种实现适合需要大规模部署和高性能检索的企业级应用场景。
IV. 会话管理与状态恢复
4.1 会话标识与管理
在跨会话记忆恢复中,会话标识是关键的一环。每个会话需要有一个唯一的标识符,以便系统能够区分不同的会话并正确恢复其状态。
LangChain提供了多种会话标识的生成和管理方式:
import uuidfrom typing import Dict, Any, Optionalclass SessionManager: """会话管理器,负责生成和管理会话ID""" def __init__(self, session_id: str = None): """初始化会话管理器""" self.session_id = session_id or self._generate_session_id() def _generate_session_id(self) -> str: """生成唯一的会话ID""" return str(uuid.uuid4()) def get_session_id(self) -> str: """获取当前会话ID""" return self.session_id def renew_session_id(self) -> str: """更新会话ID""" self.session_id = self._generate_session_id() return self.session_id def save_session_state(self, state: Dict[str, Any], storage: Any) -> None: """将会话状态保存到存储""" # 根据存储类型选择不同的保存方法 if isinstance(storage, FileChatMessageHistory): # 保存到文件 state_key = f"session_state:{self.session_id}" file_path = f"{storage.file_path}.{state_key}" try: with open(file_path, "w") as f: json.dump(state, f) except Exception as e: print(f"Error saving session state to file: {e}") elif isinstance(storage, SQLiteChatMessageHistory): # 保存到SQLite conn = sqlite3.connect(storage.db_path) cursor = conn.cursor() # 创建会话状态表(如果不存在) cursor.execute(""" CREATE TABLE IF NOT EXISTS session_states ( session_id TEXT PRIMARY KEY, state TEXT, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ) """) # 保存状态 state_json = json.dumps(state) cursor.execute( "INSERT OR REPLACE INTO session_states (session_id, state) VALUES (?, ?)", (self.session_id, state_json) ) conn.commit() conn.close() elif isinstance(storage, RedisChatMessageHistory): # 保存到Redis key = f"session_state:{self.session_id}" try: storage.redis_client.set(key, json.dumps(state)) except Exception as e: print(f"Error saving session state to Redis: {e}") def load_session_state(self, storage: Any) -> Optional[Dict[str, Any]]: """从存储加载会话状态""" # 根据存储类型选择不同的加载方法 if isinstance(storage, FileChatMessageHistory): # 从文件加载 state_key = f"session_state:{self.session_id}" file_path = f"{storage.file_path}.{state_key}" if os.path.exists(file_path): try: with open(file_path, "r") as f: return json.load(f) except Exception as e: print(f"Error loading session state from file: {e}") return None elif isinstance(storage, SQLiteChatMessageHistory): # 从SQLite加载 conn = sqlite3.connect(storage.db_path) cursor = conn.cursor() cursor.execute( "SELECT state FROM session_states WHERE session_id = ?", (self.session_id,) ) row = cursor.fetchone() conn.close() if row: return json.loads(row[0]) return None elif isinstance(storage, RedisChatMessageHistory): # 从Redis加载 key = f"session_state:{self.session_id}" state_json = storage.redis_client.get(key) if state_json: return json.loads(state_json) return None return None
SessionManager负责生成和管理会话ID,并提供会话状态的保存和加载功能。它支持多种存储后端,包括文件、SQLite和Redis,使得会话状态可以在不同的环境中持久化存储和恢复。
4.2 状态序列化与反序列化
为了实现会话状态的持久化存储和恢复,需要将会话状态序列化为某种格式(如JSON),并在需要时反序列化回原始状态。
LangChain提供了多种状态序列化和反序列化的方法:
import jsonfrom typing import Dict, Any, Optionalclass StateSerializer: """状态序列化器,负责会话状态的序列化和反序列化""" @staticmethod def serialize_state(state: Dict[str, Any]) -> str: """序列化会话状态""" try: return json.dumps(state) except Exception as e: print(f"Error serializing state: {e}") return "" @staticmethod def deserialize_state(state_str: str) -> Optional[Dict[str, Any]]: """反序列化会话状态""" try: return json.loads(state_str) except Exception as e: print(f"Error deserializing state: {e}") return None @staticmethod def serialize_memory(memory: BaseMemory) -> Dict[str, Any]: """序列化记忆对象""" # 根据记忆类型执行不同的序列化逻辑 if isinstance(memory, SimpleMemory): return { "type": "SimpleMemory", "chat_memory": memory.chat_memory } elif isinstance(memory, ChatMessageHistory): return { "type": "ChatMessageHistory", "messages": [{"type": type(msg).__name__, "content": msg.content} for msg in memory.messages] } elif isinstance(memory, VectorStoreRetrieverMemory): # 向量存储记忆的序列化更复杂,可能需要保存向量数据库连接信息 return { "type": "VectorStoreRetrieverMemory", "k": memory.k, "memory_key": memory.memory_key, # 注意:这里不保存完整的向量存储,只保存配置信息 "vectorstore_config": { "type": type(memory.vectorstore).__name__, "persist_directory": getattr(memory.vectorstore, "persist_directory", None) } } # 默认情况,尝试获取记忆的字典表示 try: return {"type": type(memory).__name__, "state": vars(memory)} except Exception: return {"type": type(memory).__name__} @staticmethod def deserialize_memory(state: Dict[str, Any], embedding_function: Optional[Embeddings] = None) -> Optional[BaseMemory]: """反序列化记忆对象""" memory_type = state.get("type") if memory_type == "SimpleMemory": memory = SimpleMemory() memory.chat_memory = state.get("chat_memory", []) return memory elif memory_type == "ChatMessageHistory": memory = ChatMessageHistory() messages = state.get("messages", []) for msg in messages: msg_type = msg.get("type") content = msg.get("content", "") if msg_type == "HumanMessage": memory.add_user_message(content) elif msg_type == "AIMessage": memory.add_ai_message(content) return memory elif memory_type == "VectorStoreRetrieverMemory": if not embedding_function: print("Embedding function is required to deserialize VectorStoreRetrieverMemory") return None vectorstore_config = state.get("vectorstore_config", {}) vectorstore_type = vectorstore_config.get("type") persist_directory = vectorstore_config.get("persist_directory") if vectorstore_type == "Chroma" and persist_directory: # 从持久化目录加载Chroma向量存储 vectorstore = Chroma( embedding_function=embedding_function, persist_directory=persist_directory ) memory = VectorStoreRetrieverMemory( vectorstore=vectorstore, embeddings=embedding_function, k=state.get("k", 5), memory_key=state.get("memory_key", "history") ) return memory return None
StateSerializer提供了会话状态和记忆对象的序列化和反序列化功能。它支持多种记忆类型,包括SimpleMemory、ChatMessageHistory和VectorStoreRetrieverMemory。对于复杂的记忆类型(如向量存储记忆),它只保存配置信息,而不是完整的向量存储,以避免数据冗余和安全风险。
4.3 会话恢复流程
会话恢复是跨会话记忆恢复技术的核心流程,它确保系统能够在会话中断后正确恢复之前的状态和记忆。
class ConversationReverter: """会话恢复器,负责会话的恢复和继续""" def __init__(self, session_manager: SessionManager, memory_store: Any, embedding_function: Optional[Embeddings] = None): """初始化会话恢复器""" self.session_manager = session_manager self.memory_store = memory_store self.embedding_function = embedding_function def create_new_session(self) -> str: """创建新会话""" # 生成新的会话ID session_id = self.session_manager.renew_session_id() # 初始化新会话的记忆 if isinstance(self.memory_store, ChatMessageHistory): # 清除现有消息 self.memory_store.clear() elif isinstance(self.memory_store, VectorStoreRetrieverMemory): # 不需要清除,向量存储可以被多个会话共享 pass return session_id def restore_session(self, session_id: str) -> bool: """恢复现有会话""" # 设置会话ID self.session_manager.session_id = session_id # 从存储加载会话状态 session_state = self.session_manager.load_session_state(self.memory_store) if not session_state: print(f"Session {session_id} not found or invalid state") return False # 恢复记忆 if "memory" in session_state: memory_state = session_state["memory"] if isinstance(self.memory_store, BaseMemory): # 如果记忆存储是BaseMemory的实例,尝试恢复其状态 new_memory = StateSerializer.deserialize_memory( memory_state, self.embedding_function ) if new_memory: # 替换当前记忆 self.memory_store = new_memory return True return False def continue_conversation(self, session_id: str, user_input: str, llm_chain: Any) -> str: """继续现有会话""" # 恢复会话 if not self.restore_session(session_id): # 如果恢复失败,创建新会话 self.create_new_session() # 准备输入 inputs = {"input": user_input} # 如果有记忆,加载记忆变量 if isinstance(self.memory_store, BaseMemory): memory_variables = self.memory_store.load_memory_variables(inputs) inputs.update(memory_variables) # 运行LLM链 output = llm_chain.run(inputs) # 保存上下文 if isinstance(self.memory_store, BaseMemory): self.memory_store.save_context(inputs, {"output": output}) # 保存会话状态 session_state = { "memory": StateSerializer.serialize_memory(self.memory_store), "last_updated": time.time() } self.session_manager.save_session_state(session_state, self.memory_store) return output
ConversationReverter实现了会话的创建、恢复和继续功能。它与SessionManager和记忆存储协同工作,确保会话状态和记忆能够正确地保存和恢复。通过调用continue_conversation方法,可以无缝地继续之前的会话,就像对话从未中断过一样。
V. 记忆检索与相似度计算
5.1 向量嵌入技术
向量嵌入技术将文本转化为高维向量空间中的点,使得语义相近的文本在向量空间中距离较近,从而为记忆检索提供基础。在LangChain中,向量嵌入的实现依赖于各类嵌入模型,以OpenAIEmbeddings
为例:
from langchain.embeddings.openai import OpenAIEmbeddingsimport os# 设置OpenAI API密钥os.environ["OPENAI_API_KEY"] = "your_api_key"# 初始化OpenAIEmbeddings实例embeddings = OpenAIEmbeddings()# 将文本转换为嵌入向量text = "如何提高代码的可读性?"embedding = embeddings.embed_query(text)print(embedding)
上述代码通过OpenAIEmbeddings
类将文本转化为嵌入向量。OpenAIEmbeddings
内部通过调用OpenAI的API来获取嵌入结果,在__init__
方法中,会读取环境变量中的API密钥,并设置默认的嵌入模型参数:
class OpenAIEmbeddings(BaseEmbeddings): """基于OpenAI的嵌入类""" openai_api_base: Optional[str] = None openai_api_type: Optional[str] = None openai_api_version: Optional[str] = None openai_api_key: Optional[str] = None model: str = "text-embedding-ada-002" # 其他参数和方法... def __init__(self, **kwargs: Any): """初始化方法""" super().__init__(**kwargs) # 从环境变量获取API密钥 self.openai_api_key = self.openai_api_key or get_from_dict_or_env( kwargs, "openai_api_key", "OPENAI_API_KEY" ) self.openai_api_base = self.openai_api_base or get_from_dict_or_env( kwargs, "openai_api_base", "OPENAI_API_BASE", default="https://api.openai.com/v1" ) self.openai_api_type = self.openai_api_type or get_from_dict_or_env( kwargs, "openai_api_type", "OPENAI_API_TYPE", default="open_ai" ) self.openai_api_version = self.openai_api_version or get_from_dict_or_env( kwargs, "openai_api_version", "OPENAI_API_VERSION" )
除了OpenAI的嵌入模型,LangChain还支持其他开源模型,如HuggingFaceEmbeddings
,通过加载本地或远程的模型来生成嵌入向量:
from langchain.embeddings import HuggingFaceEmbeddings# 选择要使用的Hugging Face模型model_name = "sentence-transformers/all-MiniLM-L6-v2"# 初始化HuggingFaceEmbeddings实例embeddings = HuggingFaceEmbeddings(model_name=model_name)text = "自然语言处理有哪些应用场景?"embedding = embeddings.embed_query(text)
HuggingFaceEmbeddings
在初始化时,会根据模型名称加载相应的模型,并设置模型的相关参数,如设备类型等:
class HuggingFaceEmbeddings(BaseEmbeddings): """基于Hugging Face的嵌入类""" model_name: str = "sentence-transformers/all-MiniLM-L6-v2" model_kwargs: Dict[str, Any] = {} encode_kwargs: Dict[str, Any] = {"normalize_embeddings": False} cache_folder: Optional[str] = None # 其他参数和方法... def __init__(self, **kwargs: Any): super().__init__(**kwargs) try: from sentence_transformers import SentenceTransformer except ImportError: raise ValueError( "Could not import sentence_transformers python package. " "Please install it with `pip install sentence-transformers`." ) self.model = SentenceTransformer( self.model_name, cache_folder=self.cache_folder, **self.model_kwargs )
5.2 相似度计算方法
在获取文本的向量表示后,需要通过相似度计算来检索相关记忆。LangChain中常用的相似度计算方法有余弦相似度、欧氏距离等。以余弦相似度为例,在向量存储的检索过程中经常会用到:
import numpy as np# 两个嵌入向量示例vector1 = np.array([0.1, 0.2, 0.3])vector2 = np.array([0.4, 0.5, 0.6])# 计算余弦相似度cosine_similarity = np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))print(cosine_similarity)
在向量存储类,如Chroma
中,检索时会默认使用余弦相似度来衡量向量之间的相似程度:
from langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddings# 初始化嵌入模型embeddings = OpenAIEmbeddings()# 示例文本texts = ["苹果是一种水果", "香蕉也是水果", "汽车是交通工具"]# 创建Chroma向量存储vectorstore = Chroma.from_texts(texts, embeddings)# 进行检索,默认使用余弦相似度query = "寻找水果相关内容"results = vectorstore.similarity_search(query)for result in results: print(result.page_content)
Chroma
的similarity_search
方法内部会将查询文本转化为向量,然后与存储的向量进行余弦相似度计算,并按照相似度得分排序返回结果:
class Chroma(VectorStore): # 其他方法和属性... def similarity_search( self, query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """通过余弦相似度搜索相关文档""" embedding = self._embedding_function.embed_query(query) docs = self.similarity_search_by_vector( embedding, k=k, filter=filter, **kwargs ) return docs
除了余弦相似度,欧氏距离也可用于衡量向量之间的差异,在一些场景下,欧氏距离能更好地反映向量在空间中的实际距离:
import numpy as npvector1 = np.array([1, 2, 3])vector2 = np.array([4, 5, 6])# 计算欧氏距离euclidean_distance = np.linalg.norm(vector1 - vector2)print(euclidean_distance)
在LangChain中,部分向量存储的检索方法也支持通过参数设置选择不同的距离度量方式,以适应不同的应用需求。
5.3 检索策略优化
为了提高记忆检索的效率和准确性,LangChain采用了多种检索策略优化方法。其中,分级检索是一种重要策略,它先进行粗粒度的检索,筛选出大致相关的记忆范围,再进行细粒度的精确检索。
from langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddingsimport numpy as np# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()texts = ["数学基础包括代数", "几何是数学的重要分支", "代数中的方程求解", "文学作品有诗歌", "小说也是文学形式"]vectorstore = Chroma.from_texts(texts, embeddings)# 分级检索示例query = "寻找代数相关内容"# 粗粒度检索,获取初步相关文档coarse_results = vectorstore.similarity_search(query, k=5)coarse_texts = [result.page_content for result in coarse_results]# 将初步结果再次转化为向量,进行细粒度检索coarse_embeddings = np.array([embeddings.embed_query(t) for t in coarse_texts])query_embedding = embeddings.embed_query(query)# 计算细粒度相似度fine_similarities = np.dot(coarse_embeddings, query_embedding) / ( np.linalg.norm(coarse_embeddings, axis=1) * np.linalg.norm(query_embedding))# 根据相似度排序,获取最终结果sorted_indices = np.argsort(fine_similarities)[::-1]fine_results = [coarse_results[i] for i in sorted_indices]for result in fine_results: print(result.page_content)
此外,LangChain还支持基于元数据的过滤检索,通过在存储记忆时添加元数据,在检索时根据元数据条件进行筛选,缩小检索范围,提高检索效率。
from langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddings# 初始化嵌入模型embeddings = OpenAIEmbeddings()texts = ["苹果是红色的水果", "香蕉是黄色的水果", "汽车是蓝色的交通工具"]metadatas = [{"category": "水果", "color": "red"}, {"category": "水果", "color": "yellow"}, {"category": "交通工具", "color": "blue"}]# 创建Chroma向量存储,并添加元数据vectorstore = Chroma.from_texts(texts, embeddings, metadatas=metadatas)# 基于元数据过滤检索query = "寻找黄色的东西"results = vectorstore.similarity_search(query, filter={"color": "yellow"})for result in results: print(result.page_content)
通过这些检索策略优化,LangChain能够在大量记忆中快速准确地找到与当前问题相关的信息,为跨会话记忆恢复提供有力支持 。
VI. 记忆更新与管理机制
6.1 增量式记忆更新
在对话过程中,记忆需要随着新信息的加入而不断更新。LangChain采用增量式记忆更新方式,避免每次都重新处理全部记忆,提高效率。以ChatMessageHistory
为例,当有新的用户消息和AI回复时,会将新消息追加到记忆列表中:
from langchain.schema import HumanMessage, AIMessagefrom langchain.memory import ChatMessageHistory# 初始化聊天消息历史history = ChatMessageHistory()# 添加用户消息user_message = "今天天气怎么样?"history.add_user_message(user_message)# 添加AI回复ai_message = "今天天气晴朗,适合外出。"history.add_ai_message(ai_message)# 查看更新后的消息历史messages = history.messagesfor msg in messages: print(f"{type(msg).__name__}: {msg.content}")
ChatMessageHistory
的add_user_message
和add_ai_message
方法实现了增量添加功能:
class ChatMessageHistory: def __init__(self): self.messages = [] def add_user_message(self, message: str) -> None: """添加用户消息""" self.messages.append(HumanMessage(content=message)) def add_ai_message(self, message: str) -> None: """添加AI消息""" self.messages.append(AIMessage(content=message))
对于向量存储类型的记忆,如VectorStoreRetrieverMemory
,在保存新的上下文时,会将新的文本添加到向量数据库中,并生成对应的向量:
from langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddingsfrom langchain.memory import VectorStoreRetrieverMemory# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()vectorstore = Chroma(embedding_function=embeddings)# 初始化向量存储记忆memory = VectorStoreRetrieverMemory(vectorstore=vectorstore, embeddings=embeddings)# 新的用户输入和AI回复user_input = "附近有什么好吃的餐厅?"ai_output = "附近的美食街有很多特色餐厅。"# 保存新的上下文memory.save_context({"input": user_input}, {"output": ai_output})
VectorStoreRetrieverMemory
的save_context
方法会调用向量存储的add_texts
方法来添加新文本:
class VectorStoreRetrieverMemory(BaseMemory): def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: """保存当前上下文到向量存储""" user_message = inputs.get("input", "") ai_message = outputs.get("output", "") full_message = f"User: {user_message}\nAI: {ai_message}" self.vectorstore.add_texts([full_message])
6.2 记忆过期与清除策略
随着对话的持续进行,记忆会不断增长,为了避免存储过多无用信息,LangChain制定了记忆过期与清除策略。一种常见的方式是基于时间的过期策略,为每条记忆设置一个过期时间,当超过该时间后,自动清除该记忆。
import timefrom typing import Dict, Anyclass TimedMemory: def __init__(self, expiration_time: int): self.memory = {} self.expiration_time = expiration_time # 过期时间(秒) def add_memory(self, key: str, value: Any) -> None: """添加记忆,并记录时间""" self.memory[key] = {"value": value, "timestamp": time.time()} def get_memory(self, key: str) -> Any: """获取记忆,检查是否过期""" memory_item = self.memory.get(key) if memory_item: current_time = time.time() if current_time - memory_item["timestamp"] > self.expiration_time: del self.memory[key] return None return memory_item["value"] return None def clear_expired_memory(self) -> None: """清除所有过期记忆""" current_time = time.time() keys_to_delete = [key for key, item in self.memory.items() if current_time - item["timestamp"] > self.expiration_time] for key in keys_to_delete: del self.memory[key]
在实际应用中,VectorStoreRetrieverMemory
等记忆类也可以通过扩展实现类似的过期清除功能。另外,还可以基于记忆的使用频率进行清除,对于长时间未被使用的记忆进行删除,释放存储空间:
class UsageBasedMemory: def __init__(self, min_usage_count: int): self.memory = {} self.min_usage_count = min_usage_count # 最小使用次数 def add_memory(self, key: str, value: Any) -> None: """添加记忆,初始化使用次数为1""" self.memory[key] = {"value": value, "usage_count": 1} def get_memory(self, key: str) -> Any: """获取记忆,增加使用次数""" memory_item = self.memory.get(key) if memory_item: memory_item["usage_count"] += 1 return memory_item["value"] return None def clear_underused_memory(self) -> None: """清除使用次数低于阈值的记忆""" keys_to_delete = [key for key, item in self.memory.items() if item["usage_count"] < self.min_usage_count] for key in keys_to_delete: del self.memory[key]
通过这些记忆过期与清除策略,LangChain能够有效地管理记忆容量,保持记忆的有效性和相关性。
6.3 记忆整合与优化
当记忆不断更新后,为了提高记忆的质量和检索效率,需要进行记忆整合与优化。在向量存储中,随着新向量的不断添加,可能会出现向量分布不均匀等问题,影响检索效果。此时可以通过重新索引来优化向量存储:
from langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddings# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()vectorstore = Chroma(embedding_function=embeddings)# 多次添加文本,模拟记忆更新texts = ["第一次添加的内容", "第二次添加的内容", "第三次添加的内容"]for text in texts: vectorstore.add_texts([text])# 重新索引优化向量存储vectorstore.recreate_index()
Chroma
的recreate_index
方法会重新构建向量索引,调整向量的存储结构,提升检索性能。此外,对于文本类型的记忆,如ChatMessageHistory
,可以通过总结归纳的方式对长对话进行整合,提取关键信息,减少记忆冗余:
from langchain.memory import ChatMessageHistoryfrom langchain.chains.summarize import load_summarize_chainfrom langchain.llms import OpenAI# 初始化聊天消息历史并添加大量消息history = ChatMessageHistory()messages = ["消息1", "消息2", "消息3", "消息4", "消息5"]for msg in messages: history.add_user_message(msg) history.add_ai_message(f"回复{msg}")# 将消息历史转换为文档格式docs = [{"page_content": msg.content} for msg in history.messages]# 初始化语言模型和总结链llm = OpenAI()chain = load_summarize_chain(llm, chain_type="stuff")# 对消息进行总结summary = chain.run(docs)# 用总结后的内容更新记忆(示例,实际可根据需求调整)history.messages = []history.add_user_message("对话总结")history.add_ai_message(summary)
VII. 上下文感知与动态记忆应用
7.1 上下文理解机制
在跨会话记忆恢复中,准确理解上下文是关键。LangChain通过结合历史记忆与当前输入,构建出完整的上下文信息,从而使模型更好地理解用户意图。在代码实现层面,以RetrievalQA
链为例,它会将检索到的记忆与当前问题组合成新的输入,传递给语言模型:
from langchain.document_loaders import TextLoaderfrom langchain.text_splitter import CharacterTextSplitterfrom langchain.embeddings import OpenAIEmbeddingsfrom langchain.vectorstores import Chromafrom langchain.chains import RetrievalQAfrom langchain.llms import OpenAI# 加载文档loader = TextLoader('example.txt')documents = loader.load()# 文本分割text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)texts = text_splitter.split_documents(documents)# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()vectorstore = Chroma.from_documents(texts, embeddings)# 创建检索器retriever = vectorstore.as_retriever()# 初始化语言模型和问答链llm = OpenAI()qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)# 提出问题,结合上下文获取答案question = "文档中提到的核心技术是什么?"result = qa.run(question)print(result)
在RetrievalQA
内部,_call
方法负责整合上下文:
class RetrievalQA(BaseQA): # ... def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs[self.question_key] # 从向量存储中检索相关文档 docs = self.retriever.get_relevant_documents(question) # 将问题和相关文档组合成新的输入 new_inputs = { self.question_key: question, self.input_key: [doc.page_content for doc in docs] } # 运行问答链 answer = self.combine_documents_chain.run(new_inputs) return {self.output_key: answer}
此外,LangChain还支持通过自定义提示模板来引导模型理解上下文。例如,在构建问答链时,可以设置提示模板强调上下文的重要性:
from langchain import PromptTemplate, LLMChainfrom langchain.llms import OpenAIprompt = PromptTemplate( input_variables=["context", "question"], template="根据以下上下文信息,回答问题:\n上下文:{context}\n问题:{question}")llm = OpenAI()llm_chain = LLMChain(prompt=prompt, llm=llm)context = "苹果是一种富含维生素的水果"question = "苹果有什么营养?"result = llm_chain.run(context=context, question=question)print(result)
7.2 动态记忆调整
随着对话的推进,记忆的重要性和相关性会发生变化,LangChain具备动态记忆调整能力。在向量存储记忆中,当新的记忆添加后,会重新评估所有记忆与后续问题的相关性权重。例如,VectorStoreRetrieverMemory
在每次检索时,会根据新的查询更新记忆的“活跃度”标记:
class VectorStoreRetrieverMemory(BaseMemory): def __init__(self, vectorstore, embeddings, k=5, memory_key="history"): self.vectorstore = vectorstore self.embeddings = embeddings self.k = k self.memory_key = memory_key self.memory_activity = {} # 记录记忆活跃度 def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: query = inputs.get("input", "") if not query: return {self.memory_key: []} query_embedding = self.embeddings.embed_query(query) docs = self.retriever.get_relevant_documents(query) history = [doc.page_content for doc in docs] # 更新记忆活跃度 for doc in docs: doc_id = doc.metadata.get('id') self.memory_activity[doc_id] = self.memory_activity.get(doc_id, 0) + 1 return {self.memory_key: history}
对于基于时间的记忆,也会动态调整过期时间。如果某段记忆在近期频繁被使用,会适当延长其过期时间;反之,则缩短过期时间:
class DynamicTimedMemory: def __init__(self, base_expiration_time: int, usage_factor: float): self.memory = {} self.base_expiration_time = base_expiration_time self.usage_factor = usage_factor # 使用频率影响因子 def add_memory(self, key: str, value: Any) -> None: self.memory[key] = {"value": value, "timestamp": time.time(), "usage_count": 1} def get_memory(self, key: str) -> Any: memory_item = self.memory.get(key) if memory_item: current_time = time.time() # 根据使用频率调整过期时间 adjusted_time = self.base_expiration_time * (1 + (memory_item["usage_count"] - 1) * self.usage_factor) if current_time - memory_item["timestamp"] > adjusted_time: del self.memory[key] return None memory_item["usage_count"] += 1 return memory_item["value"] return None
7.3 多轮对话中的记忆应用
在多轮对话场景下,LangChain充分利用记忆来保持对话连贯性。每一轮对话的输入和输出都会被保存到记忆中,为下一轮对话提供参考。以聊天机器人为例:
from langchain.memory import ConversationBufferMemoryfrom langchain.chat_models import ChatOpenAIfrom langchain.chains import ConversationChain# 初始化记忆和语言模型memory = ConversationBufferMemory()llm = ChatOpenAI()# 创建对话链conversation = ConversationChain( llm=llm, memory = memory, verbose=True)# 第一轮对话response_1 = conversation.predict(input="你好")print(response_1)# 第二轮对话,记忆会传递上一轮信息response_2 = conversation.predict(input="能推荐一部电影吗?")print(response_2)
ConversationBufferMemory
会将每一轮对话的内容以字符串形式缓存起来,并在生成新回复时,将缓存内容添加到输入中。其load_memory_variables
方法实现了这一过程:
class ConversationBufferMemory(BaseMemory): def __init__(self, memory_key="history", input_key="input", output_key="output"): self.memory_key = memory_key self.input_key = input_key self.output_key = output_key self.buffer = "" def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return {self.memory_key: self.buffer} def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: input_str = self._get_input_str(inputs) output_str = self._get_output_str(outputs) self.buffer += f"{input_str}{output_str}"
通过这种方式,模型在多轮对话中能够参考之前的对话内容,理解用户的连续意图,实现更加自然流畅的交流 。
VIII. 多模态记忆处理
8.1 多模态数据存储
随着应用场景的拓展,LangChain需要处理图像、音频等多模态数据的记忆。在存储层面,对于图像数据,可通过特征提取将图像转化为向量后存储在向量数据库中。以使用OpenAI
的图像嵌入模型为例:
import openaifrom langchain.vectorstores import Chromafrom langchain.embeddings.openai import OpenAIEmbeddingsopenai.api_key = "your_api_key"# 图像文件路径image_path = "example.jpg"# 读取图像文件with open(image_path, "rb") as image_file: image_data = image_file.read()# 使用OpenAI获取图像嵌入向量response = openai.Image.create_embedding( image=image_data, model="image-embedding-001")image_embedding = response['data'][0]['embedding']# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()vectorstore = Chroma(embedding_function=embeddings)# 将图像嵌入向量和相关元数据存储vectorstore.add_vectors([image_embedding], metadatas=[{"source": "image", "path": image_path}])
对于音频数据,可先将音频转换为文本(如通过语音识别技术),再将文本转化为向量进行存储。例如,结合Whisper
进行语音识别和LangChain
进行向量存储:
import whisperfrom langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddings# 加载Whisper模型model = whisper.load_model("base")# 音频文件路径audio_path = "example_audio.wav"# 进行语音识别result = model.transcribe(audio_path)transcribed_text = result["text"]# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()vectorstore = Chroma(embedding_function=embeddings)# 存储文本对应的向量vectorstore.add_texts([transcribed_text], metadatas=[{"source": "audio", "path": audio_path}])
8.2 多模态记忆检索
在多模态记忆检索时,需要综合考虑不同模态数据的特征。当检索与图像相关的记忆时,可通过计算图像向量与查询向量的相似度来获取结果。例如,在Chroma
向量存储中检索相似图像:
from langchain.vectorstores import Chromafrom langchain.embeddings.openai import OpenAIEmbeddings# 初始化嵌入模型和向量存储(假设已存储图像向量)embeddings = OpenAIEmbeddings()vectorstore = Chroma(embedding_function=embeddings)# 新图像文件路径,用于生成查询向量new_image_path = "new_example.jpg"with open(new_image_path, "rb") as image_file: new_image_data = image_file.read()response = openai.Image.create_embedding( image=new_image_data, model="image-embedding-001")new_image_embedding = response['data'][0]['embedding']# 检索相似图像记忆results = vectorstore.similarity_search_by_vector(new_image_embedding)for result in results: print(result.metadata)
对于混合模态的检索,如同时包含文本和图像的查询,LangChain可以将不同模态的嵌入向量进行融合处理。例如,将文本向量和图像向量拼接或加权求和后,再进行检索:
import numpy as npfrom langchain.vectorstores import Chromafrom langchain.embeddings import OpenAIEmbeddings# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()vectorstore = Chroma(embedding_function=embeddings)# 文本查询text_query = "寻找与猫相关的内容"text_embedding = embeddings.embed_query(text_query)# 图像查询(假设已获取图像嵌入向量)image_embedding = np.array([0.1, 0.2, 0.3])# 融合向量combined_embedding = np.concatenate((text_embedding, image_embedding))# 进行检索results = vectorstore.similarity_search_by_vector(combined_embedding)for result in results: print(result.page_content)
8.3 多模态记忆融合与应用
多模态记忆的融合旨在将不同模态的信息整合,为模型提供更丰富的上下文。在实际应用中,当生成回复时,可以将多模态记忆与文本记忆结合,输入到语言模型中。例如,在一个图文问答系统中:
from langchain.document_loaders import TextLoaderfrom langchain.text_splitter import CharacterTextSplitterfrom langchain.embeddings import OpenAIEmbeddingsfrom langchain.vectorstores import Chromafrom langchain.chains import RetrievalQAfrom langchain.llms import OpenAI# 加载文本文档text_loader = TextLoader('text_example.txt')text_documents = text_loader.load()text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)texts = text_splitter.split_documents(text_documents)# 假设已存储图像向量到同一向量存储中# 初始化嵌入模型和向量存储embeddings = OpenAIEmbeddings()vectorstore = Chroma.from_documents(texts, embeddings)# 创建问答链llm = OpenAI()qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever())# 结合图像和文本的查询query = "图像中的场景在文档中有提到吗?"result = qa.run(query)print(result)
在内部实现上,RetrievalQA
链会将检索到的多模态记忆与文本记忆统一处理,通过提示模板等方式引导语言模型理解融合后的信息,从而生成更准确的回复 。
IX. 分布式环境下的记忆管理
9.1 分布式存储架构
在分布式环境中,为了实现记忆的有效管理,LangChain采用分布式存储架构。以基于Redis的分布式存储为例,多个节点可以共享Redis中的记忆数据:
import redisfrom langchain.memory import RedisChatMessageHistory# 初始化Redis连接redis_client = redis.Redis(host='localhost', port=6379, db=0)# 节点1使用的聊天消息历史history_1 = RedisChatMessageHistory(redis_client=redis_client, session_id="session_1")history_1.add_user_message("节点1的用户消息")history_1.add_ai_message("节点1的AI回复")# 节点2使用的聊天消息历史history_2 = RedisChatMessageHistory(redis_client=redis_client, session_id="session_1")# 节点2可以获取到节点1保存的消息messages = history_2.messagesfor msg in messages: print(f"{type(msg).__name__}: {msg.content}")
对于向量存储,也可以采用分布式向量数据库,如Pinecone。多个应用实例可以连接到同一个Pinecone索引,实现记忆的共享和同步:
import pineconefrom langchain.vectorstores import Pineconefrom langchain.embeddings import OpenAIEmbeddings# 初始化Pineconepinecone.init(api_key="your_api_key", environment="your_env")index_name = "your_index_name"embeddings = OpenAIEmbeddings()# 节点1向向量存储添加记忆vectorstore_1 = Pinecone.from_existing_index(index_name, embeddings)vectorstore_1.add_texts(["节点1添加的文本"])# 节点2连接到同一索引,获取记忆vectorstore_2 = Pinecone.from_existing_index(index_name, embeddings)results = vectorstore_2.similarity_search("查询")for result in results: print(result.page_content)
9.2 数据同步与一致性
在分布式环境中,数据同步和一致性是关键问题。对于基于文件的记忆存储,可通过分布式文件系统(如Ceph)实现数据同步。在LangChain中,当使用FileChatMessageHistory
时,只要多个节点都能访问分布式文件系统的路径,就能实现记忆同步:
from langchain.memory import FileChatMessageHistory# 假设分布式文件系统路径file_path = "/distributed_storage/chat_history.json"# 节点1保存消息history_1 = FileChatMessageHistory(file_path)history_1.add_user_message("节点1消息")history_1.add_ai_message("节点1回复")# 节点2获取同步后的消息history_2 = FileChatMessageHistory(file_path)messages = history_2.messagesfor msg in messages: print(f"{type(msg).__name__}: {msg.content}")
对于数据库存储,如SQLite,可以通过数据库复制技术实现数据同步。而对于Redis、Pinecone等分布式数据库,自身具备数据同步和一致性保障机制。例如,Redis通过主从复制实现数据同步,Pinecone通过内部的集群管理机制保证数据在多个节点间的一致性。
9.3 负载均衡与故障容错
为了提高系统的可用性和性能,LangChain在分布式环境下采用负载均衡和故障容错策略。在记忆存储的访问层面,可以使用负载均衡器(如Nginx)来分配对Redis、Pinecone等存储的请求。以Nginx配置为例,实现对Redis的负载均衡:
upstream redis_servers { server redis_server_1:6379; server redis_server_2:6379; # 可以添加更多Redis服务器 least_conn; # 采用最少连接数算法}server { listen 80; server_name your_domain; location /redis { proxy_pass http://redis_servers; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; }}
在故障容错方面,当某个存储节点出现故障时,LangChain可以通过自动切换到备用节点来保证服务的连续性。以使用Redis集群为例,redis-py
库提供了对Redis集群故障转移的支持,在LangChain的记忆模块中可以进行如下适配:
import redisclusterfrom langchain.memory import RedisChatMessageHistory# 初始化Redis集群节点列表redis_nodes = [ {"host": "redis_node_1", "port": 6379}, {"host": "redis_node_2", "port": 6379}, {"host": "redis_node_3", "port": 6379}]# 创建Redis集群客户端rc = rediscluster.RedisCluster(startup_nodes=redis_nodes, decode_responses=True)# 使用Redis集群客户端初始化聊天消息历史history = RedisChatMessageHistory(redis_client=rc, session_id="my_session")history.add_user_message("测试消息")
当Redis集群中的某个节点发生故障时,redis-py
库会自动检测并将请求重定向到其他可用节点,确保记忆的读写操作不受影响。
对于向量存储,以Pinecone为例,其自身具备高可用性和自动故障恢复能力。在LangChain中使用Pinecone时,即使某个Pinecone的服务节点出现问题,应用程序也无需手动干预,因为Pinecone的SDK会自动处理连接和请求的重路由:
import pineconefrom langchain.vectorstores import Pineconefrom langchain.embeddings import OpenAIEmbeddingspinecone.init(api_key="your_api_key", environment="your_env")index_name = "your_index_name"embeddings = OpenAIEmbeddings()vectorstore = Pinecone.from_existing_index(index_name, embeddings)query = "查找相关内容"results = vectorstore.similarity_search(query)
此外,在应用层可以通过心跳检测和重试机制进一步增强容错能力。例如,创建一个心跳检测函数,定期检查记忆存储服务的可用性:
import timeimport redisdef check_redis_availability(redis_host, redis_port): try: r = redis.Redis(host=redis_host, port=redis_port) r.ping() return True except redis.ConnectionError: return Falsewhile True: if not check_redis_availability("localhost", 6379): print("Redis服务不可用,尝试重新连接...") # 执行重试逻辑或切换到备用存储 time.sleep(5)
在记忆读取和写入操作中,也可以添加重试逻辑,当出现连接失败或超时等错误时,自动进行重试:
import redisfrom langchain.memory import RedisChatMessageHistoryimport timeredis_client = redis.Redis(host='localhost', port=6379, db=0)history = RedisChatMessageHistory(redis_client=redis_client, session_id="my_session")max_retries = 3retry_delay = 2for attempt in range(max_retries): try: history.add_user_message("测试消息") break except redis.ConnectionError: if attempt < max_retries - 1: print(f"连接Redis失败,第{attempt + 1}次重试,等待{retry_delay}秒...") time.sleep(retry_delay) else: print("多次重试失败,无法保存记忆")
X. 安全与隐私保护机制
10.1 数据加密
在跨会话记忆恢复过程中,用户的对话数据可能包含敏感信息,因此数据加密至关重要。LangChain支持对存储的记忆数据进行加密处理。以文件存储为例,使用cryptography
库对聊天消息历史文件进行加密:
from cryptography.fernet import Fernetfrom langchain.memory import ChatMessageHistoryimport os# 生成加密密钥key = Fernet.generate_key()cipher_suite = Fernet(key)# 初始化聊天消息历史history = ChatMessageHistory()history.add_user_message("敏感信息内容")history.add_ai_message("相关回复")# 将消息历史序列化为字符串messages_str = str([{"type": type(msg).__name__, "content": msg.content} for msg in history.messages])# 加密消息字符串encrypted_messages = cipher_suite.encrypt(messages_str.encode())# 保存加密后的消息到文件with open("encrypted_chat_history.bin", "wb") as f: f.write(encrypted_messages)# 读取加密文件并解密with open("encrypted_chat_history.bin", "rb") as f: read_encrypted_messages = f.read()decrypted_messages = cipher_suite.decrypt(read_encrypted_messages).decode()print(decrypted_messages)
对于数据库存储,如SQLite,可以在写入数据前进行加密,读取数据时进行解密。在SQLiteChatMessageHistory
类中可以添加加密解密方法:
import sqlite3from cryptography.fernet import Fernetfrom langchain.memory import ChatMessageHistoryclass EncryptedSQLiteChatMessageHistory(ChatMessageHistory): def __init__(self, db_path: str, session_id: str, key: bytes): super().__init__() self.db_path = db_path self.session_id = session_id self.cipher_suite = Fernet(key) self._create_table() self._load_from_db() def _encrypt_message(self, message: str) -> bytes: return self.cipher_suite.encrypt(message.encode()) def _decrypt_message(self, encrypted_message: bytes) -> str: return self.cipher_suite.decrypt(encrypted_message).decode() def _load_from_db(self) -> None: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute( "SELECT message_type, content FROM chat_messages WHERE session_id = ? ORDER BY timestamp ASC", (self.session_id,) ) rows = cursor.fetchall() conn.close() for msg_type, encrypted_content in rows: content = self._decrypt_message(encrypted_content) if msg_type == "HumanMessage": self.add_user_message(content) elif msg_type == "AIMessage": self.add_ai_message(content) def add_user_message(self, message: str) -> None: super().add_user_message(message) self._save_to_db("HumanMessage", self._encrypt_message(message)) def add_ai_message(self, message: str) -> None: super().add_ai_message(message) self._save_to_db("AIMessage", self._encrypt_message(message)) def _save_to_db(self, message_type: str, encrypted_content: bytes) -> None: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute( "INSERT INTO chat_messages (session_id, message_type, content) VALUES (?, ?, ?)", (self.session_id, message_type, encrypted_content) ) conn.commit() conn.close()
对于向量存储中的嵌入向量,虽然直接加密向量会影响相似度计算,但可以对向量的元数据进行加密,保护与向量相关的敏感信息。
10.2 访问控制
LangChain通过访问控制机制确保只有授权的用户或服务能够访问和操作记忆数据。可以基于角色的访问控制(RBAC)来实现,定义不同角色及其对应的权限:
from enum import Enumclass Role(Enum): ADMIN = "admin" USER = "user" GUEST = "guest"class MemoryAccessControl: def __init__(self): self.role_permissions = { Role.ADMIN: {"read": True, "write": True, "delete": True}, Role.USER: {"read": True, "write": True, "delete": False}, Role.GUEST: {"read": True, "write": False, "delete": False} } def has_permission(self, role: Role, operation: str) -> bool: return self.role_permissions[role].get(operation, False) def check_permission(self, role: Role, operation: str): if not self.has_permission(role, operation): raise PermissionError(f"角色{role.value}没有{operation}权限")# 使用示例access_control = MemoryAccessControl()user_role = Role.USERaccess_control.check_permission(user_role, "read")try: access_control.check_permission(user_role, "delete")except PermissionError as e: print(e)
在实际应用中,将访问控制与记忆操作相结合。例如,在RedisChatMessageHistory
类中添加权限检查:
import redisfrom langchain.memory import ChatMessageHistoryfrom typing import Unionclass AccessControlledRedisChatMessageHistory(ChatMessageHistory): def __init__(self, redis_client: redis.Redis, session_id: str, role: Union[Role, str]): super().__init__() self.redis_client = redis_client self.session_id = session_id self.role = Role(role) if isinstance(role, str) else role self.access_control = MemoryAccessControl() self._load_from_redis() def _load_from_redis(self) -> None: self.access_control.check_permission(self.role, "read") key = f"chat_history:{self.session_id}" messages_json = self.redis_client.get(key) if messages_json: messages_dict = eval(messages_json) for msg_dict in messages_dict: msg_type = msg_dict["type"] content = msg_dict["content"] if msg_type == "HumanMessage": self.add_user_message(content) elif msg_type == "AIMessage": self.add_ai_message(content) def add_user_message(self, message: str) -> None: self.access_control.check_permission(self.role, "write") super().add_user_message(message) self._save_to_redis() def add_ai_message(self, message: str) -> None: self.access_control.check_permission(self.role, "write") super().add_ai_message(message) self._save_to_redis() def _save_to_redis(self) -> None: key = f"chat_history:{self.session_id}" messages_dict = [{"type": type(msg).__name__, "content": msg.content} for msg in self.messages] self.redis_client.set(key, str(messages_dict))
通过这种方式,不同角色的用户对记忆数据的操作受到严格限制,保障数据安全。
10.3 隐私数据脱敏
除了加密和访问控制,对隐私数据进行脱敏处理也是保护用户隐私的重要手段。在记忆存储和处理过程中,识别并替换敏感信息。例如,使用正则表达式对聊天消息中的邮箱、手机号等信息进行脱敏:
import refrom langchain.memory import ChatMessageHistoryclass PrivacyProtectedChatMessageHistory(ChatMessageHistory): def __init__(self): super().__init__() self.email_pattern = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b") self.phone_pattern = re.compile(r"\b(?:\+?\d{1,3}[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b") def _redact_message(self, message: str) -> str: redacted_message = self.email_pattern.sub("[REDACTED_EMAIL]", message) redacted_message = self.phone_pattern.sub("[REDACTED_PHONE]", redacted_message) return redacted_message def add_user_message(self, message: str) -> None: redacted_message = self._redact_message(message) super().add_user_message(redacted_message) def add_ai_message(self, message: str) -> None: redacted_message = self._redact_message(message) super().add_ai_message(redacted_message)
在向量存储的元数据处理中,同样可以进行隐私数据脱敏,避免敏感信息泄露,全方位保护用户隐私 。
以上从多个维度深入解析了LangChain跨会话记忆恢复技术。如果你还想了解某部分的拓展内容,或对其他技术细节感兴趣,欢迎随时和我说。