Magnetic-UI源码解析
TeamManager 的实现
create_team
async def _create_team(
    self,
    team_config: Union[str, Path, Dict[str, Any], ComponentModel],
    state: Optional[Mapping[str, Any] | str] = None,
    input_func: Optional[InputFuncType] = None,
    env_vars: Optional[List[EnvironmentVariable]] = None,
    settings_config: Optional[dict[str, Any]] = None,
    *,
    paths: RunPaths,
) -> tuple[Team, int, int]:
    """Create the team for a run and return it with its browser ports.

    The per-agent model client configuration is resolved from two sources:
    the instance configuration (``self.config``) and the ``model_configs``
    YAML string carried in ``settings_config``. Which one wins depends on
    ``self.load_from_config``:

    * ``True``  -- ``self.config`` first, UI settings only when the file
      value is missing or falsy.
    * ``False`` -- UI settings first, ``self.config`` only when the key is
      absent from the UI settings.

    Args:
        team_config: Accepted for interface compatibility; not read by this
            method.
        state: Optional saved team state. A string is first treated as a
            compressed state and, if decompression fails, as plain JSON.
        input_func: Forwarded to ``get_task_team``.
        env_vars: Accepted for interface compatibility; not read by this
            method.
        settings_config: Settings coming from the UI. ``None`` is treated
            as an empty dict.
        paths: Run directories; forwarded to ``get_task_team`` and used to
            look up the browser resource configuration.

    Returns:
        ``(team, novnc_port, playwright_port)``. The ports are taken from
        the team's ``WebSurfer`` participant when one is found.

    Raises:
        Exception: Any error raised while building the team or loading its
            state is logged, ``self.close()`` is awaited, and the error is
            re-raised unchanged.
    """
    # Never share a mutable default between calls.
    if settings_config is None:
        settings_config = {}
    # The debug log below already treats self.config as possibly empty/None,
    # so guard every other access the same way.
    file_config: Dict[str, Any] = self.config or {}

    # Browser resource configuration (noVNC port and Playwright port).
    _, novnc_port, playwright_port = get_browser_resource_config(
        paths.external_run_dir, -1, -1, self.inside_docker
    )

    try:
        if self.load_from_config:
            logger.info("Loading team from configuration file")
        else:
            logger.info("Loading team from team_config (UI mode)")

        # The UI provides model configs under `model_configs` as a YAML
        # string, while MagenticUIConfig expects `model_client_configs`.
        settings_model_configs: Dict[str, Any] = {}
        if "model_configs" in settings_config:
            try:
                loaded = yaml.safe_load(settings_config["model_configs"])
            except Exception as e:
                logger.warning(
                    f"Error loading model configs from UI. Using defaults. Inner exception: {e}"
                )
            else:
                # safe_load returns None for an empty document and may
                # return a scalar or list; only a mapping is usable here.
                if isinstance(loaded, dict):
                    settings_model_configs = loaded
                elif loaded is not None:
                    logger.warning(
                        "Model configs from UI are not a mapping. Using defaults."
                    )

        if self.load_from_config:
            logger.info(
                f"Debug - self.config keys: {list(self.config.keys()) if self.config else 'None'}"
            )
            logger.info(f"Debug - settings_model_configs: {settings_model_configs}")
            logger.info(f"Debug - load_from_config: {self.load_from_config}")
            logger.info("Prioritizing config file over UI settings")

        def resolve_client(key: str) -> Any:
            """Pick one client config according to the active priority."""
            if self.load_from_config:
                # File config wins; fall back to UI when missing or falsy.
                return file_config.get(key) or settings_model_configs.get(key)
            # UI wins; fall back to file config only when the key is absent.
            return settings_model_configs.get(key, file_config.get(key, None))

        orchestrator_config = resolve_client("orchestrator_client")
        if self.load_from_config:
            logger.info(f"Debug - final orchestrator_config: {orchestrator_config}")

        model_client_configs = ModelClientConfigs(
            orchestrator=orchestrator_config,
            web_surfer=resolve_client("web_surfer_client"),
            coder=resolve_client("coder_client"),
            file_surfer=resolve_client("file_surfer_client"),
            action_guard=resolve_client("action_guard_client"),
        )

        magentic_ui_config = MagenticUIConfig(
            **{
                # Lowest priority defaults
                **file_config,  # type: ignore
                # Provided settings override defaults
                **settings_config,  # type: ignore
                # Manually merged model client configs
                "model_client_configs": model_client_configs,
                # These must always be the values computed above
                "playwright_port": playwright_port,
                "novnc_port": novnc_port,
                "inside_docker": self.inside_docker,
            }
        )

        self.team = cast(
            Team,
            await get_task_team(
                magentic_ui_config=magentic_ui_config,
                input_func=input_func,
                paths=paths,
            ),
        )

        # Report the ports the WebSurfer agent actually uses, if there is one.
        if hasattr(self.team, "_participants"):
            for agent in cast(list[ChatAgent], self.team._participants):  # type: ignore
                if isinstance(agent, WebSurfer):
                    novnc_port = agent.novnc_port
                    playwright_port = agent.playwright_port

        if state:
            if isinstance(state, str):
                # Only the decoding step may fall back to JSON; a failure of
                # load_state itself must surface as-is rather than being
                # masked by a JSON parse error on compressed data.
                try:
                    state_dict = decompress_state(state)
                except Exception:
                    state_dict = json.loads(state)
                await self.team.load_state(state_dict)
            else:
                await self.team.load_state(state)

        return self.team, novnc_port, playwright_port
    except Exception as e:
        logger.error(f"Error creating team: {e}")
        await self.close()
        raise
该函数根据传入配置(配置文件或UI配置)动态创建一个“团队”(Team 实例),并返回该团队实例和与浏览器资源相关的端口信息(novnc_port, playwright_port)
支持两种方式创建团队:
- 1️⃣ 从配置文件加载团队(self.load_from_config = True):模型客户端配置优先取自配置文件,缺失时回退到 UI 设置
- 2️⃣ 从传入参数直接构建团队(UI 模式):模型客户端配置优先取自 UI 设置,缺失时回退到配置文件
run_stream
async def run_stream(
    self,
    task: Optional[Union[ChatMessage, str, Sequence[ChatMessage]]],
    team_config: Union[str, Path, dict[str, Any], ComponentModel],
    state: Optional[Mapping[str, Any] | str] = None,
    input_func: Optional[InputFuncType] = None,
    cancellation_token: Optional[CancellationToken] = None,
    env_vars: Optional[List[EnvironmentVariable]] = None,
    settings_config: Optional[Dict[str, Any]] = None,
    run: Optional[Run] = None,
) -> AsyncGenerator[
    Union[AgentEvent, ChatMessage, LLMCallEventMessage, TeamResult], None
]:
    """Run the team on ``task`` and stream everything it produces.

    Yields, in order of occurrence:

    * one system ``TextMessage`` of type ``browser_address`` carrying the
      noVNC and Playwright ports, right after the team is created;
    * a system ``TextMessage`` of type ``file`` whenever files not seen
      before appear in the run directory (names starting with ``tmp_code``
      are filtered out);
    * every message from the team's own ``run_stream``; a ``TaskResult`` is
      wrapped in a ``TeamResult`` with the elapsed duration and file list;
    * after a ``final_answer`` message, one more ``file`` message listing
      all files collected so far, de-duplicated by name;
    * any events queued on the run's ``RunEventLogger``.

    The ``finally`` clause removes the event-log handler and closes the team.

    NOTE(review): the source this block was recovered from had lost its
    indentation. The nesting below places the streaming loop inside
    ``if self.team is None`` because ``_novnc_port`` is only bound there --
    confirm against the upstream file.
    """
    # Reference point for durations and for the file-modification window.
    start_time = time.time()

    # Route the event logger exclusively to a per-run handler.
    # NOTE: this local ``logger`` shadows any module-level logger of the
    # same name for the rest of the function.
    logger = logging.getLogger(EVENT_LOGGER_NAME)
    logger.setLevel(logging.CRITICAL)
    llm_event_logger = RunEventLogger()
    logger.handlers = [llm_event_logger]  # replaces all existing handlers
    logger.info(f"Running in docker: {self.inside_docker}")

    # Resolve the directories used by this run.
    paths = self.prepare_run_paths(run=run)

    # Names of files already present, so that only new ones are reported.
    known_files = set(
        file["name"]
        for file in get_modified_files(
            0, time.time(), source_dir=str(paths.internal_run_dir)
        )
    )
    # Every new file reported during the run, re-sent with the final answer.
    global_new_files: List[Dict[str, str]] = []

    try:
        # TODO: This might cause problems later if we are not careful
        if self.team is None:
            # TODO: if we start allowing load from config, we'll need to write
            # the novnc and playwright ports back into the team config
            _, _novnc_port, _playwright_port = await self._create_team(
                team_config,
                state,
                input_func,
                env_vars,
                settings_config or {},
                paths=paths,
            )

            # Re-baseline the known files using the window since start_time.
            initial_files = get_modified_files(
                start_time, time.time(), source_dir=str(paths.internal_run_dir)
            )
            known_files = {file["name"] for file in initial_files}

            # Tell the client where the browser can be reached.
            yield TextMessage(
                source="system",
                content=f"Browser noVNC address can be found at http://localhost:{_novnc_port}/vnc.html",
                metadata={
                    "internal": "no",
                    "type": "browser_address",
                    "novnc_port": str(_novnc_port),
                    "playwright_port": str(_playwright_port),
                },
            )

            async for message in self.team.run_stream(  # type: ignore
                task=task, cancellation_token=cancellation_token
            ):
                # Stop consuming as soon as cancellation is requested.
                if cancellation_token and cancellation_token.is_cancelled():
                    break

                # Snapshot the files modified since the run started.
                modified_files = get_modified_files(
                    start_time, time.time(), source_dir=str(paths.internal_run_dir)
                )
                current_file_names = {file["name"] for file in modified_files}

                # New files = current snapshot minus the previous snapshot.
                new_file_names = current_file_names - known_files
                known_files = current_file_names  # baseline for the next iteration

                # Keep the full records of the new files only.
                new_files = [
                    file for file in modified_files if file["name"] in new_file_names
                ]

                if new_files:
                    # Drop files whose name starts with "tmp_code".
                    new_files = [
                        file
                        for file in new_files
                        if not file["name"].startswith("tmp_code")
                    ]
                    if len(new_files) > 0:
                        file_message = TextMessage(
                            source="system",
                            content="File Generated",
                            metadata={
                                "internal": "no",
                                "type": "file",
                                "files": json.dumps(new_files),
                            },
                        )
                        global_new_files.extend(new_files)
                        yield file_message

                # Wrap the terminal TaskResult; pass everything else through.
                if isinstance(message, TaskResult):
                    yield TeamResult(
                        task_result=message,
                        usage="",
                        duration=time.time() - start_time,
                        files=modified_files,  # keep the full file records
                    )
                else:
                    yield message

                # After the final answer, re-send all generated files at once.
                if (
                    isinstance(message, TextMessage)
                    and message.metadata.get("type", "") == "final_answer"
                ):
                    if len(global_new_files) > 0:
                        # One entry per file name; the dict keeps the last
                        # record seen for a name.
                        global_new_files = list(
                            {file["name"]: file for file in global_new_files}.values()
                        )
                        file_message = TextMessage(
                            source="system",
                            content="File Generated",
                            metadata={
                                "internal": "no",
                                "type": "file",
                                "files": json.dumps(global_new_files),
                            },
                        )
                        yield file_message
                        global_new_files = []

                # Drain any LLM events queued by the handler.
                while not llm_event_logger.events.empty():
                    event = await llm_event_logger.events.get()
                    yield event
    finally:
        # Cleanup - remove our handler
        if llm_event_logger in logger.handlers:
            logger.handlers.remove(llm_event_logger)

        # Ensure the team is closed even on error or early exit.
        if self.team and hasattr(self.team, "close"):
            logger.info("Closing team")
            await self.team.close()  # type: ignore
            logger.info("Team closed")
run_stream 函数的核心功能:
初始化团队(如未初始化)
启动任务的异步运行
监听运行过程中的输出:
- 普通 AI 消息(ChatMessage)
- 最终结果(TaskResult → TeamResult)
- 文件生成事件(系统消息)
- LLM 事件(如 Token 使用等)
支持任务取消
运行结束后进行资源清理
传递浏览器消息
# 生成浏览器地址信息消息 yield TextMessage( source="system", content=f"Browser noVNC address can be found at http://localhost:{_novnc_port}/vnc.html", metadata={ "internal": "no", "type": "browser_address", "novnc_port": str(_novnc_port), "playwright_port": str(_playwright_port), }, )
启动异步任务执行
async for message in self.team.run_stream(...):
从团队 run_stream 接收 逐条消息输出,包括中间内容、LLM回复、系统提示、最终结果。
文件监控
modified_files = get_modified_files(...)new_file_names = current_file_names - known_files # 集合求差
get_modified_files 使用 os.path.getmtime() 获取文件的最后修改时间,然后与给定的时间范围进行比较,只保留时间范围内的文件。
每次团队输出后,检查目录中是否有新增文件。如果检测到新文件,就封装为 TextMessage:
yield TextMessage(content="File Generated", metadata={...})
并添加到 global_new_files 用于最终输出。
输出最终结果(如有)
if isinstance(message, TaskResult): yield TeamResult(...)
最终文件再次确认输出
if message.metadata.get("type", "") == "final_answer": # 发送全局文件信息
LLM事件日志输出
while not llm_event_logger.events.empty(): event = await llm_event_logger.events.get() yield event
WebSocketManager 的实现
WebSocketManager类主要负责管理WebSocket连接和团队任务执行的消息流。
import asyncio
import inspect
import json
import logging
import traceback
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union

from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import (
    AgentEvent,
    ChatMessage,
    HandoffMessage,
    ModelClientStreamingChunkEvent,
    MultiModalMessage,
    StopMessage,
    TextMessage,
    ToolCallExecutionEvent,
    ToolCallRequestEvent,
)
from autogen_core import CancellationToken
from fastapi import WebSocket, WebSocketDisconnect

from ....input_func import InputFuncType, InputRequestType
from ....types import CheckpointEvent
from ...database import DatabaseManager
from ...datamodel import (
    LLMCallEventMessage,
    Message,
    MessageConfig,
    Run,
    RunStatus,
    Settings,
    SettingsConfig,
    TeamResult,
)
from ...teammanager import TeamManager
from ...utils.utils import compress_state

logger = logging.getLogger(__name__)


class WebSocketManager:
    """Manage WebSocket connections and message streaming for team runs.

    Args:
        db_manager (DatabaseManager): Database manager used for all
            persistence.
        internal_workspace_root (Path): Path to the internal root directory.
        external_workspace_root (Path): Path to the external root directory.
        inside_docker (bool): Whether the application runs inside Docker.
        config (dict): Configuration for Magentic-UI.
    """

    def __init__(
        self,
        db_manager: DatabaseManager,
        internal_workspace_root: Path,
        external_workspace_root: Path,
        inside_docker: bool,
        config: Dict[str, Any],
    ):
        self.db_manager = db_manager
        self.internal_workspace_root = internal_workspace_root
        self.external_workspace_root = external_workspace_root
        self.inside_docker = inside_docker
        self.config = config

        # All of the following are keyed by run_id.
        self._connections: Dict[int, WebSocket] = {}
        self._cancellation_tokens: Dict[int, CancellationToken] = {}
        # Track explicitly closed connections
        self._closed_connections: set[int] = set()
        self._input_responses: Dict[int, asyncio.Queue[str]] = {}
        self._team_managers: Dict[int, TeamManager] = {}

        # Pre-built result used when a run is cancelled by the user.
        self._cancel_message = TeamResult(
            task_result=TaskResult(
                messages=[TextMessage(source="user", content="Run cancelled by user")],
                stop_reason="cancelled by user",
            ),
            usage="",
            duration=0,
        ).model_dump()

    def _get_stop_message(self, reason: str) -> dict[str, Any]:
        """Build a dumped ``TeamResult`` whose message and stop reason are ``reason``."""
        return TeamResult(
            task_result=TaskResult(
                messages=[TextMessage(source="user", content=reason)],
                stop_reason=reason,
            ),
            usage="",
            duration=0,
        ).model_dump()

    async def connect(self, websocket: WebSocket, run_id: int) -> bool:
        """Accept ``websocket``, register it for ``run_id`` and greet the client.

        Returns:
            bool: ``True`` on success, ``False`` if any step raised (the
            error is logged).
        """
        try:
            await websocket.accept()
            self._connections[run_id] = websocket
            self._closed_connections.discard(run_id)
            # Initialize input queue for this connection
            self._input_responses[run_id] = asyncio.Queue()

            await self._send_message(
                run_id,
                {
                    "type": "system",
                    "status": "connected",
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            return True
        except Exception as e:
            logger.error(f"Connection error for run {run_id}: {e}")
            return False

    async def start_stream(
        self,
        run_id: int,
        task: str | ChatMessage | Sequence[ChatMessage] | None,
        team_config: Dict[str, Any],
        settings_config: Dict[str, Any],
        user_settings: Settings | None = None,
    ) -> None:
        """Start streaming task execution with proper run management.

        Args:
            run_id (int): ID of the run.
            task (str | ChatMessage | Sequence[ChatMessage] | None): Task to
                execute.
            team_config (Dict[str, Any]): Configuration for the team.
            settings_config (Dict[str, Any]): Configuration for settings.
                ``memory_controller_key`` is written into this dict.
            user_settings (Settings, optional): Ignored; the settings are
                always re-read from the database for the run's user.

        Raises:
            ValueError: If there is no active connection for ``run_id``.
                Errors raised while streaming are not propagated; they are
                logged and reported via ``_handle_stream_error``.
        """
        if run_id not in self._connections or run_id in self._closed_connections:
            raise ValueError(f"No active connection for run {run_id}")

        # do not create a new team manager if one already exists
        if run_id not in self._team_managers:
            team_manager = TeamManager(
                internal_workspace_root=self.internal_workspace_root,
                external_workspace_root=self.external_workspace_root,
                inside_docker=self.inside_docker,
                config=self.config,
            )
            self._team_managers[run_id] = team_manager
        else:
            team_manager = self._team_managers[run_id]

        cancellation_token = CancellationToken()
        self._cancellation_tokens[run_id] = cancellation_token
        final_result = None

        try:
            # Update run with task and status
            run = await self._get_run(run_id)
            # Explicit checks rather than assert: asserts vanish under -O.
            if run is None:
                raise ValueError(f"Run {run_id} not found in database")
            if run.user_id is None:
                raise ValueError(f"Run {run_id} has no user ID")

            # Get user Settings
            user_settings = await self._get_settings(run.user_id)
            env_vars = (
                SettingsConfig(**user_settings.config).environment  # type: ignore
                if user_settings
                else None
            )
            settings_config["memory_controller_key"] = run.user_id

            run.task = MessageConfig(content=task, source="user").model_dump()
            run.status = RunStatus.ACTIVE
            state = run.state
            self.db_manager.upsert(run)
            await self._update_run_status(run_id, RunStatus.ACTIVE)

            # add task as message
            if isinstance(task, str):
                await self._send_message(
                    run_id,
                    self._format_message(TextMessage(source="user_proxy", content=task))
                    or {},
                )
                await self._save_message(
                    run_id, TextMessage(source="user_proxy", content=task)
                )
            elif isinstance(task, Sequence):
                for task_message in task:
                    if isinstance(task_message, (TextMessage, MultiModalMessage)):
                        if (
                            hasattr(task_message, "metadata")
                            and task_message.metadata.get("internal") == "yes"
                        ):
                            continue
                        await self._send_message(
                            run_id, self._format_message(task_message) or {}
                        )
                        await self._save_message(run_id, task_message)

            input_func: InputFuncType = self.create_input_func(run_id)
            message: ChatMessage | AgentEvent | TeamResult | LLMCallEventMessage
            async for message in team_manager.run_stream(
                task=task,
                team_config=team_config,
                state=state,
                input_func=input_func,
                cancellation_token=cancellation_token,
                env_vars=env_vars,
                settings_config=settings_config,
                run=run,
            ):
                if (
                    cancellation_token.is_cancelled()
                    or run_id in self._closed_connections
                ):
                    logger.info(
                        f"Stream cancelled or connection closed for run {run_id}"
                    )
                    break

                if isinstance(message, CheckpointEvent):
                    # Save state to run
                    run = await self._get_run(run_id)
                    if run:
                        # Use compress_state utility to compress the state
                        state_dict = json.loads(message.state)
                        run.state = compress_state(state_dict)
                        self.db_manager.upsert(run)
                    continue

                # do not show internal messages
                if (
                    hasattr(message, "metadata")
                    and message.metadata.get("internal") == "yes"  # type: ignore
                ):
                    continue

                formatted_message = self._format_message(message)
                if formatted_message:
                    await self._send_message(run_id, formatted_message)

                    # Save messages by concrete type
                    if isinstance(
                        message,
                        (
                            TextMessage,
                            MultiModalMessage,
                            StopMessage,
                            HandoffMessage,
                            ToolCallRequestEvent,
                            ToolCallExecutionEvent,
                            LLMCallEventMessage,
                        ),
                    ):
                        await self._save_message(run_id, message)
                    # Capture final result if it's a TeamResult
                    elif isinstance(message, TeamResult):
                        final_result = message.model_dump()

            self._team_managers[run_id] = team_manager  # Track the team manager

            if (
                not cancellation_token.is_cancelled()
                and run_id not in self._closed_connections
            ):
                if final_result:
                    await self._update_run(
                        run_id, RunStatus.COMPLETE, team_result=final_result
                    )
                else:
                    logger.warning(
                        f"No final result captured for completed run {run_id}"
                    )
                    await self._update_run_status(run_id, RunStatus.COMPLETE)
            else:
                await self._send_message(
                    run_id,
                    {
                        "type": "completion",
                        "status": "cancelled",
                        "data": self._cancel_message,
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                # Update run with cancellation result
                await self._update_run(
                    run_id, RunStatus.STOPPED, team_result=self._cancel_message
                )

        except Exception as e:
            logger.error(f"Stream error for run {run_id}: {e}")
            traceback.print_exc()
            await self._handle_stream_error(run_id, e)
        finally:
            self._cancellation_tokens.pop(run_id, None)
            self._team_managers.pop(run_id, None)  # Remove the team manager when done

    async def _save_message(
        self, run_id: int, message: Union[AgentEvent | ChatMessage, LLMCallEventMessage]
    ) -> None:
        """Save a message to the database.

        Nothing is written when the run cannot be found.

        Args:
            run_id (int): ID of the run.
            message (Union[AgentEvent | ChatMessage, LLMCallEventMessage]):
                Message to save.
        """
        run = await self._get_run(run_id)
        if run:
            db_message = Message(
                created_at=datetime.now(),
                session_id=run.session_id,
                run_id=run_id,
                config=message.model_dump(),
                user_id=run.user_id,  # Pass the user_id from the run object
            )
            self.db_manager.upsert(db_message)

    async def _update_run(
        self,
        run_id: int,
        status: RunStatus,
        team_result: Optional[TeamResult | Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> None:
        """Update run status and result in the database only.

        Unlike ``_update_run_status`` this never sends anything over the
        WebSocket.

        Args:
            run_id (int): ID of the run.
            status (RunStatus): New status to set.
            team_result (TeamResult | dict[str, Any], optional): Team result
                to store; left untouched when falsy.
            error (str, optional): Error message to store; left untouched
                when falsy.
        """
        run = await self._get_run(run_id)
        if run:
            run.status = status
            if team_result:
                run.team_result = team_result
            if error:
                run.error_message = error
            self.db_manager.upsert(run)

    def create_input_func(self, run_id: int, timeout: int = 600) -> InputFuncType:
        """Create an input function for a specific run.

        Args:
            run_id (int): ID of the run.
            timeout (int, optional): Timeout for the input response in
                seconds. Default: 600.

        Returns:
            InputFuncType: Coroutine function that asks the client for input
            and returns the reply.
        """

        async def input_handler(
            prompt: str = "",
            cancellation_token: Optional[CancellationToken] = None,
            input_type: InputRequestType = "text_input",
        ) -> str:
            try:
                # resume run if it is paused
                await self.resume_run(run_id)

                # update run status to awaiting_input
                await self._update_run_status(run_id, RunStatus.AWAITING_INPUT)

                # Send input request to client
                logger.info(
                    f"Sending input request for run {run_id}: ({input_type}) {prompt}"
                )
                await self._send_message(
                    run_id,
                    {
                        "type": "input_request",
                        "input_type": input_type,
                        "prompt": prompt,
                        "data": {"source": "system", "content": prompt},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )

                # Store input_request in the Run object
                run = await self._get_run(run_id)
                if run:
                    run.input_request = {"prompt": prompt, "input_type": input_type}
                    self.db_manager.upsert(run)

                if run_id not in self._input_responses:
                    raise ValueError(f"No input queue for run {run_id}")

                async def poll_for_response():
                    """Poll the queue in short slices so closure/cancel is noticed."""
                    while True:
                        # Check if run was closed/cancelled
                        if run_id in self._closed_connections:
                            raise ValueError("Run was closed")

                        if cancellation_token and hasattr(
                            cancellation_token, "is_cancelled"
                        ):
                            try:
                                cancelled = cancellation_token.is_cancelled()
                            except Exception:
                                # If checking cancellation status fails,
                                # assume it's cancelled
                                cancelled = True
                            if cancelled:
                                raise ValueError("Run was cancelled")

                        # Try to get response with short timeout
                        try:
                            response = await asyncio.wait_for(
                                self._input_responses[run_id].get(),
                                timeout=min(timeout, 5),
                            )
                            await self._update_run_status(run_id, RunStatus.ACTIVE)
                            return response
                        except asyncio.TimeoutError:
                            continue  # Keep checking for closed status

                # Wait for response with timeout
                try:
                    return await asyncio.wait_for(poll_for_response(), timeout=timeout)
                except asyncio.TimeoutError:
                    # Stop the run if timeout occurs
                    logger.warning(f"Input response timeout for run {run_id}")
                    await self.stop_run(
                        run_id,
                        "Magentic-UI timed out while waiting for your input. To resume, please enter a follow-up message in the input box or you can simply type 'continue'.",
                    )
                    raise

            except Exception as e:
                logger.error(f"Error handling input for run {run_id}: {e}")
                raise

        return input_handler

    async def handle_input_response(self, run_id: int, response: str) -> None:
        """Queue an input response from the client for the waiting input handler."""
        if run_id in self._input_responses:
            await self._input_responses[run_id].put(response)
        else:
            logger.warning(f"Received input response for inactive run {run_id}")

    async def _cancel_token(self, run_id: int, *, during_disconnect: bool = False) -> None:
        """Cancel the run's cancellation token, never raising.

        Supports both synchronous and coroutine ``cancel`` methods. Failures
        are logged as warnings; ``during_disconnect`` only selects the
        wording of those warnings.
        """
        token = self._cancellation_tokens.get(run_id)
        if not token:
            return
        phase = " during disconnect" if during_disconnect else ""
        try:
            cancel = getattr(token, "cancel", None)
            if callable(cancel):
                if inspect.iscoroutinefunction(cancel):
                    await cancel()
                else:
                    cancel()
            elif not during_disconnect:
                logger.warning(
                    f"Cancellation token for run {run_id} does not have a cancel method"
                )
        except Exception as cancel_error:
            logger.warning(
                f"Error cancelling token{phase} for run {run_id}: {cancel_error}"
            )

    async def _close_team_manager(
        self, run_id: int, *, during_disconnect: bool = False
    ) -> None:
        """Remove the run's team manager and close it, never raising."""
        team_manager = self._team_managers.pop(run_id, None)
        if not team_manager:
            return
        phase = " during disconnect" if during_disconnect else ""
        try:
            await team_manager.close()
        except Exception as close_error:
            logger.warning(
                f"Error closing team manager{phase} for run {run_id}: {close_error}"
            )

    async def stop_run(self, run_id: int, reason: str) -> None:
        """Stop a running task: persist, notify the client, cancel, close.

        Does nothing when the run has no cancellation token. If any step
        fails, the error is logged and the connection is force-disconnected
        unless it is already marked closed.

        Args:
            run_id (int): ID of the run to stop.
            reason (str): Reason stored as the stop message.
        """
        if run_id not in self._cancellation_tokens:
            return

        logger.info(f"Stopping run {run_id}")
        stop_message = self._get_stop_message(reason)

        try:
            # Update run record first
            await self._update_run(
                run_id, status=RunStatus.STOPPED, team_result=stop_message
            )

            # Then handle websocket communication if connection is active
            if (
                run_id in self._connections
                and run_id not in self._closed_connections
            ):
                await self._send_message(
                    run_id,
                    {
                        "type": "completion",
                        "status": "cancelled",
                        "data": stop_message,
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )

            # Finally cancel the token and remove the team manager
            await self._cancel_token(run_id)
            await self._close_team_manager(run_id)

        except Exception as e:
            logger.error(f"Error stopping run {run_id}: {e}")
            # Only force disconnect if we're not already in disconnection process
            if run_id not in self._closed_connections:
                try:
                    await self.disconnect(run_id)
                except Exception as disconnect_error:
                    logger.error(
                        f"Error during forced disconnect for run {run_id}: {disconnect_error}"
                    )

    async def disconnect(self, run_id: int) -> None:
        """Clean up connection and associated resources.

        Args:
            run_id (int): ID of the run to disconnect.
        """
        logger.info(f"Disconnecting run {run_id}")

        # Mark as closed before cleanup to prevent any new messages
        self._closed_connections.add(run_id)

        # Cancel directly instead of calling stop_run, to avoid infinite
        # recursion between stop_run and disconnect.
        if run_id in self._cancellation_tokens:
            await self._cancel_token(run_id, during_disconnect=True)

        await self._close_team_manager(run_id, during_disconnect=True)

        # Clean up resources
        self._connections.pop(run_id, None)
        self._cancellation_tokens.pop(run_id, None)
        self._input_responses.pop(run_id, None)

    async def _send_message(self, run_id: int, message: Dict[str, Any]) -> None:
        """Send a message through the WebSocket with connection state checking.

        A message for a closed connection is dropped with a warning. A send
        failure disconnects the run; it is never re-raised.

        Args:
            run_id (int): ID of the run.
            message (Dict[str, Any]): Message dictionary to send.
        """
        if run_id in self._closed_connections:
            logger.warning(
                f"Attempted to send message to closed connection for run {run_id}"
            )
            return

        try:
            if run_id in self._connections:
                websocket = self._connections[run_id]
                await websocket.send_json(message)
        except WebSocketDisconnect:
            logger.warning(
                f"WebSocket disconnected while sending message for run {run_id}"
            )
            await self.disconnect(run_id)
        except Exception as e:
            logger.error(f"Error sending message for run {run_id}: {e}, {message}")
            # Record the failure in the database only. _update_run_status
            # would call back into _send_message on the same broken socket
            # before the run is marked closed, recursing without bound.
            await self._update_run(run_id, RunStatus.ERROR, error=str(e))
            await self.disconnect(run_id)

    async def _handle_stream_error(self, run_id: int, error: Exception) -> None:
        """Handle stream errors with proper run updates.

        Does nothing when the connection is already marked closed.

        Args:
            run_id (int): ID of the run.
            error (Exception): Exception that occurred.
        """
        if run_id not in self._closed_connections:
            error_result = TeamResult(
                task_result=TaskResult(
                    messages=[TextMessage(source="system", content=str(error))],
                    stop_reason="An error occurred while processing this run",
                ),
                usage="",
                duration=0,
            ).model_dump()

            await self._send_message(
                run_id,
                {
                    "type": "completion",
                    "status": "error",
                    "data": error_result,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

            await self._update_run(
                run_id, RunStatus.ERROR, team_result=error_result, error=str(error)
            )

    def _format_message(self, message: Any) -> Optional[Dict[str, Any]]:
        """Format message for WebSocket transmission.

        Args:
            message (Any): Message to format.

        Returns:
            Optional[Dict[str, Any]]: Formatted message, or ``None`` when the
            type is not handled or formatting fails.
        """
        try:
            if isinstance(message, MultiModalMessage):
                message_dump = message.model_dump()

                # Rows carrying "data" become data-URL image entries.
                message_content: list[dict[str, Any]] = []
                for row in message_dump["content"]:
                    if "data" in row:
                        message_content.append(
                            {
                                "url": f"data:image/png;base64,{row['data']}",
                                "alt": "WebSurfer Screenshot",
                            }
                        )
                    else:
                        message_content.append(row)
                message_dump["content"] = message_content

                return {"type": "message", "data": message_dump}

            elif isinstance(message, TeamResult):
                return {
                    "type": "result",
                    "data": message.model_dump(),
                    "status": "complete",
                }
            elif isinstance(message, ModelClientStreamingChunkEvent):
                return {"type": "message_chunk", "data": message.model_dump()}

            elif isinstance(
                message,
                (TextMessage,),
            ):
                return {"type": "message", "data": message.model_dump()}
            elif isinstance(message, str):
                return {
                    "type": "message",
                    "data": {"source": "user", "content": message},
                }

            return None

        except Exception as e:
            logger.error(f"Message formatting error: {e}")
            return None

    async def _get_run(self, run_id: int) -> Optional[Run]:
        """Get run from database.

        Args:
            run_id (int): ID of the run to retrieve.

        Returns:
            Optional[Run]: Run object if found, None otherwise.
        """
        response = self.db_manager.get(Run, filters={"id": run_id}, return_json=False)
        return response.data[0] if response.status and response.data else None

    async def _get_settings(self, user_id: str) -> Optional[Settings]:
        """Get user settings from database.

        Args:
            user_id (str): User ID to retrieve settings for.

        Returns:
            Optional[Settings]: User settings if found, None otherwise.
        """
        response = self.db_manager.get(
            filters={"user_id": user_id}, model_class=Settings, return_json=False
        )
        return response.data[0] if response.status and response.data else None

    async def _update_run_status(
        self, run_id: int, status: RunStatus, error: Optional[str] = None
    ) -> None:
        """Update run status in the database and notify the client.

        ``error_message`` is always overwritten, so passing no error clears
        a previously stored one.

        Args:
            run_id (int): ID of the run to update.
            status (RunStatus): New status to set.
            error (str, optional): Optional error message.
        """
        run = await self._get_run(run_id)
        if run:
            run.status = status
            run.error_message = error
            self.db_manager.upsert(run)
            # send system message to client with status
            await self._send_message(
                run_id,
                {
                    "type": "system",
                    "status": status,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

    async def cleanup(self) -> None:
        """Clean up all active connections and resources when server is shutting down."""
        logger.info(f"Cleaning up {len(self.active_connections)} active connections")

        try:
            # First cancel all running tasks
            for run_id in self.active_runs.copy():
                if run_id in self._cancellation_tokens:
                    self._cancellation_tokens[run_id].cancel()
                run = await self._get_run(run_id)
                if run and run.status == RunStatus.ACTIVE:
                    interrupted_result = TeamResult(
                        task_result=TaskResult(
                            messages=[
                                TextMessage(
                                    source="system",
                                    content="Run interrupted by server shutdown",
                                )
                            ],
                            stop_reason="server_shutdown",
                        ),
                        usage="",
                        duration=0,
                    ).model_dump()

                    run.status = RunStatus.STOPPED
                    run.team_result = interrupted_result
                    self.db_manager.upsert(run)

            # Then disconnect all websockets with timeout
            # 10 second timeout for entire cleanup
            async def disconnect_all():
                for run_id in self.active_connections.copy():
                    try:
                        await asyncio.wait_for(self.disconnect(run_id), timeout=2)
                    except asyncio.TimeoutError:
                        logger.warning(f"Timeout disconnecting run {run_id}")
                    except Exception as e:
                        logger.error(f"Error disconnecting run {run_id}: {e}")

            await asyncio.wait_for(disconnect_all(), timeout=10)

        except asyncio.TimeoutError:
            logger.warning("WebSocketManager cleanup timed out")
        except Exception as e:
            logger.error(f"Error during WebSocketManager cleanup: {e}")
        finally:
            # Always clear internal state, even if cleanup had errors
            self._connections.clear()
            self._cancellation_tokens.clear()
            self._closed_connections.clear()
            self._input_responses.clear()
            # Drop managers left behind when a disconnect timed out.
            self._team_managers.clear()

    @property
    def active_connections(self) -> set[int]:
        """Get set of active run IDs (registered and not marked closed)."""
        return set(self._connections.keys()) - self._closed_connections

    @property
    def active_runs(self) -> set[int]:
        """Get set of runs with active cancellation tokens."""
        return set(self._cancellation_tokens.keys())

    async def pause_run(self, run_id: int) -> None:
        """Pause the run if it has an open connection and a team manager."""
        if (
            run_id in self._connections
            and run_id not in self._closed_connections
            and run_id in self._team_managers
        ):
            team_manager = self._team_managers.get(run_id)
            if team_manager:
                await team_manager.pause_run()
                # await self._send_message(
                #     run_id,
                #     {
                #         "type": "system",
                #         "status": "paused",
                #         "timestamp": datetime.now(timezone.utc).isoformat(),
                #     },
                # )
                # await self._update_run_status(run_id, RunStatus.PAUSED)

    async def resume_run(self, run_id: int) -> None:
        """Resume the run and mark it ACTIVE, if it has a connection and team manager."""
        if (
            run_id in self._connections
            and run_id not in self._closed_connections
            and run_id in self._team_managers
        ):
            team_manager = self._team_managers.get(run_id)
            if team_manager:
                await team_manager.resume_run()
                await self._update_run_status(run_id, RunStatus.ACTIVE)
connect
async def connect(self, websocket: WebSocket, run_id: int) -> bool:
    """Accept a WebSocket and register it as the connection for a run.

    Args:
        websocket (WebSocket): Incoming socket to accept.
        run_id (int): ID of the run this socket belongs to.

    Returns:
        bool: True when the socket was accepted, registered, given a fresh
        input queue and the "connected" system message was handed to
        ``_send_message`` without an exception; False if any step raised.
    """
    try:
        await websocket.accept()

        # Register the socket and make sure the run is no longer marked closed.
        self._connections[run_id] = websocket
        self._closed_connections.discard(run_id)

        # Each connection gets its own queue for user input replies.
        self._input_responses[run_id] = asyncio.Queue()

        connected_notice = {
            "type": "system",
            "status": "connected",
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self._send_message(run_id, connected_notice)
    except Exception as exc:
        logger.error(f"Connection error for run {run_id}: {exc}")
        return False
    else:
        return True
处理单个 WebSocket 客户端连接的注册与初始化:记录连接、创建输入队列、发送欢迎消息。
start_stream
async def start_stream(
    self,
    run_id: int,
    task: str | ChatMessage | Sequence[ChatMessage] | None,
    team_config: Dict[str, Any],
    settings_config: Dict[str, Any],
    user_settings: Settings | None = None,
) -> None:
    """
    Start streaming task execution with proper run management.

    Args:
        run_id (int): ID of the run
        task (str | ChatMessage | Sequence[ChatMessage] | None): Task to execute
        team_config (Dict[str, Any]): Configuration for the team
        settings_config (Dict[str, Any]): Configuration for settings
        user_settings (Settings, optional): User settings for the run

    Raises:
        ValueError: If ``run_id`` has no registered connection or the
            connection is marked closed. Exceptions raised after that point
            are logged and passed to ``_handle_stream_error`` instead of
            propagating.
    """
    # NOTE(review): nesting below was reconstructed from a flattened listing;
    # confirm the indentation against the upstream file.
    if run_id not in self._connections or run_id in self._closed_connections:
        raise ValueError(f"No active connection for run {run_id}")

    # Do not create a new team manager if one already exists for this run.
    if run_id not in self._team_managers:
        team_manager = TeamManager(
            internal_workspace_root=self.internal_workspace_root,
            external_workspace_root=self.external_workspace_root,
            inside_docker=self.inside_docker,
            config=self.config,
        )
        self._team_managers[run_id] = team_manager
    else:
        team_manager = self._team_managers[run_id]

    # A fresh cancellation token is registered for every stream of this run.
    cancellation_token = CancellationToken()
    self._cancellation_tokens[run_id] = cancellation_token
    final_result = None

    try:
        # Update run with task and status.
        run = await self._get_run(run_id)
        assert run is not None, f"Run {run_id} not found in database"
        assert run.user_id is not None, f"Run {run_id} has no user ID"

        # Get user settings.
        # NOTE(review): this overwrites the ``user_settings`` argument, so the
        # value passed by the caller is never used.
        user_settings = await self._get_settings(run.user_id)
        env_vars = (
            SettingsConfig(**user_settings.config).environment  # type: ignore
            if user_settings
            else None
        )
        # NOTE(review): mutates the caller's ``settings_config`` dict in place.
        settings_config["memory_controller_key"] = run.user_id

        state = None
        if run:
            run.task = MessageConfig(content=task, source="user").model_dump()
            run.status = RunStatus.ACTIVE
            state = run.state
            self.db_manager.upsert(run)
            await self._update_run_status(run_id, RunStatus.ACTIVE)

        # Add the task as a message: send it to the client and save it.
        if isinstance(task, str):
            await self._send_message(
                run_id,
                self._format_message(TextMessage(source="user_proxy", content=task))
                or {},
            )
            await self._save_message(
                run_id, TextMessage(source="user_proxy", content=task)
            )
        elif isinstance(task, Sequence):
            for task_message in task:
                if isinstance(task_message, TextMessage) or isinstance(
                    task_message, MultiModalMessage
                ):
                    # Messages flagged internal are neither sent nor saved.
                    if (
                        hasattr(task_message, "metadata")
                        and task_message.metadata.get("internal") == "yes"
                    ):
                        continue
                    await self._send_message(
                        run_id, self._format_message(task_message) or {}
                    )
                    await self._save_message(run_id, task_message)

        input_func: InputFuncType = self.create_input_func(run_id)
        message: ChatMessage | AgentEvent | TeamResult | LLMCallEventMessage
        async for message in team_manager.run_stream(
            task=task,
            team_config=team_config,
            state=state,
            input_func=input_func,
            cancellation_token=cancellation_token,
            env_vars=env_vars,
            settings_config=settings_config,
            run=run,
        ):
            # Stop consuming the stream once cancelled or the client is gone.
            if (
                cancellation_token.is_cancelled()
                or run_id in self._closed_connections
            ):
                logger.info(
                    f"Stream cancelled or connection closed for run {run_id}"
                )
                break

            if isinstance(message, CheckpointEvent):
                # Save state to run.
                run = await self._get_run(run_id)
                if run:
                    # Use compress_state utility to compress the state.
                    state_dict = json.loads(message.state)
                    run.state = compress_state(state_dict)
                    self.db_manager.upsert(run)
                continue

            # Do not show internal messages.
            if (
                hasattr(message, "metadata")
                and message.metadata.get("internal") == "yes"  # type: ignore
            ):
                continue

            formatted_message = self._format_message(message)
            if formatted_message:
                await self._send_message(run_id, formatted_message)

                # Save messages by concrete type.
                if isinstance(
                    message,
                    (
                        TextMessage,
                        MultiModalMessage,
                        StopMessage,
                        HandoffMessage,
                        ToolCallRequestEvent,
                        ToolCallExecutionEvent,
                        LLMCallEventMessage,
                    ),
                ):
                    await self._save_message(run_id, message)
                # Capture final result if it's a TeamResult.
                elif isinstance(message, TeamResult):
                    final_result = message.model_dump()
                    self._team_managers[run_id] = team_manager  # Track the team manager

        if (
            not cancellation_token.is_cancelled()
            and run_id not in self._closed_connections
        ):
            if final_result:
                await self._update_run(
                    run_id, RunStatus.COMPLETE, team_result=final_result
                )
            else:
                logger.warning(
                    f"No final result captured for completed run {run_id}"
                )
                await self._update_run_status(run_id, RunStatus.COMPLETE)
        else:
            await self._send_message(
                run_id,
                {
                    "type": "completion",
                    "status": "cancelled",
                    "data": self._cancel_message,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            # Update run with cancellation result.
            await self._update_run(
                run_id, RunStatus.STOPPED, team_result=self._cancel_message
            )

    except Exception as e:
        logger.error(f"Stream error for run {run_id}: {e}")
        traceback.print_exc()
        await self._handle_stream_error(run_id, e)
    finally:
        self._cancellation_tokens.pop(run_id, None)
        self._team_managers.pop(run_id, None)  # Remove the team manager when done
start_stream() 负责启动一个指定任务的 AI 团队流式运行流程:准备数据 → 执行 → 捕获输出 → 监控状态 → 推送结果。
检查连接是否存在
if run_id not in self._connections or run_id in self._closed_connections: raise ValueError(f"No active connection for run {run_id}")
确保前端已经通过 WebSocket 成功连接。
准备 TeamManager 实例
if run_id not in self._team_managers: team_manager = TeamManager(...)
- 每个 run_id 对应一个团队实例
- 避免重复创建,确保复用已有管理器
准备取消控制器
cancellation_token = CancellationToken()
self._cancellation_tokens[run_id] = cancellation_token
获取运行状态与用户配置
run = await self._get_run(run_id)
user_settings = await self._get_settings(run.user_id)
env_vars = SettingsConfig(...).environment
- 从数据库加载当前运行(如先前状态、用户ID)
- 读取用户配置,生成环境变量(用于模型或代码执行环境)
更新数据库运行状态并保存初始任务
run.task = ...
run.status = RunStatus.ACTIVE
同时向前端推送消息:
await self._send_message(run_id, self._format_message(TextMessage(...)) or {})
await self._save_message(run_id, TextMessage(...))
创建输入函数,用于团队运行时交互
input_func: InputFuncType = self.create_input_func(run_id)
供运行中需要用户反馈的代理使用。
create_input_func
def create_input_func(self, run_id: int, timeout: int = 600) -> InputFuncType:
    """
    Build the input coroutine that agents of a specific run call to ask the user.

    Args:
        run_id (int): ID of the run
        timeout (int, optional): Timeout for the input response in seconds. Default: 600

    Returns:
        InputFuncType: Input function for the run. It sends an input request
        to the client, then returns the next item taken from the run's input
        queue. It re-raises any error after logging it; on overall timeout it
        first asks ``stop_run`` to stop the run.
    """

    async def input_handler(
        prompt: str = "",
        cancellation_token: Optional[CancellationToken] = None,
        input_type: InputRequestType = "text_input",
    ) -> str:
        # Never wait longer than 5 seconds on the queue in one go, so the
        # closed/cancelled checks run regularly.
        poll_interval = min(timeout, 5)

        async def wait_for_reply():
            while True:
                # Give up as soon as the run's connection is marked closed.
                if run_id in self._closed_connections:
                    raise ValueError("Run was closed")

                # Give up when the token reports cancellation. The ValueError
                # raised inside the try is caught by the except below and
                # re-raised with the same message; a failure while checking
                # the token is treated as cancelled too.
                if cancellation_token and hasattr(cancellation_token, "is_cancelled"):
                    try:
                        if cancellation_token.is_cancelled():
                            raise ValueError("Run was cancelled")
                    except Exception:
                        raise ValueError("Run was cancelled")

                # Short wait on the queue; on timeout loop back to the checks.
                try:
                    reply = await asyncio.wait_for(
                        self._input_responses[run_id].get(),
                        timeout=poll_interval,
                    )
                    await self._update_run_status(run_id, RunStatus.ACTIVE)
                    return reply
                except asyncio.TimeoutError:
                    continue

        try:
            # Resume the run if it is paused, then flag it as awaiting input.
            await self.resume_run(run_id)
            await self._update_run_status(run_id, RunStatus.AWAITING_INPUT)

            # Send the input request to the client.
            logger.info(
                f"Sending input request for run {run_id}: ({input_type}) {prompt}"
            )
            request_payload = {
                "type": "input_request",
                "input_type": input_type,
                "prompt": prompt,
                "data": {"source": "system", "content": prompt},
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
            await self._send_message(run_id, request_payload)

            # Store the pending request on the Run object.
            current_run = await self._get_run(run_id)
            if current_run:
                current_run.input_request = {
                    "prompt": prompt,
                    "input_type": input_type,
                }
                self.db_manager.upsert(current_run)

            # Without a queue there is nowhere for a reply to arrive.
            if run_id not in self._input_responses:
                raise ValueError(f"No input queue for run {run_id}")

            # Wait for the reply, bounded by the overall timeout.
            try:
                return await asyncio.wait_for(wait_for_reply(), timeout=timeout)
            except asyncio.TimeoutError:
                # Stop the run if the overall timeout expires.
                logger.warning(f"Input response timeout for run {run_id}")
                await self.stop_run(
                    run_id,
                    "Magentic-UI timed out while waiting for your input. To resume, please enter a follow-up message in the input box or you can simply type 'continue'.",
                )
                raise
        except Exception as e:
            logger.error(f"Error handling input for run {run_id}: {e}")
            raise

    return input_handler
为指定运行(run_id)创建一个异步函数,支持在代理任务中向用户发出输入请求、等待用户响应,并处理异常或超时情况
恢复运行并更新状态
await self.resume_run(run_id)
await self._update_run_status(run_id, RunStatus.AWAITING_INPUT)
- 如果之前被暂停,尝试恢复运行
- 更改任务状态为“等待输入”(供前端显示 UI)
发送输入请求到前端
await self._send_message( run_id, { "type": "input_request", "input_type": input_type, "prompt": prompt, ... },)
- 消息类型为 "input_request"
- 提供提示文本、输入类型(如 "text_input")
将请求记录到数据库(Run 对象中)
run.input_request = {"prompt": prompt, "input_type": input_type}
self.db_manager.upsert(run)
等待用户输入(带超时)
if run_id in self._input_responses:
队列由 connect() 初始化,客户端消息处理器会通过此队列投递用户输入。
async def poll_for_response(): ... response = await asyncio.wait_for(self._input_responses[run_id].get(), ...)
启动总超时等待
response = await asyncio.wait_for(poll_for_response(), timeout=timeout)