掘金 人工智能 13小时前
Magnetic-UI源码解析
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入解析了Magnetic-UI框架中创建团队实例(Team)及执行任务(run_stream)的核心逻辑。`_create_team`函数支持通过配置文件或UI配置来初始化团队,并能根据`load_from_config`标志决定加载方式,同时整合了浏览器资源(如novnc和playwright端口)以及模型客户端配置。`run_stream`函数则负责启动团队执行任务,并以流式方式处理和输出包括AI消息、LLM调用、文件生成以及最终结果在内的各种事件。该函数还实现了对新生成文件的实时监控和收集,并能在任务执行过程中处理取消操作,最后进行资源清理。

✅ **团队实例的动态创建与配置**:`_create_team`函数是创建Magnetic-UI中核心“团队”对象(Team instance)的入口。它能够根据传入的配置信息,无论是来自本地配置文件还是直接的UI设置,灵活地初始化团队。此过程会获取必要的浏览器资源配置,如novnc和playwright的端口信息,并将这些配置与模型客户端(如orchestrator, web_surfer, coder等)的配置相结合,最终构建出完整的MagenticUIConfig对象,用于实例化Team。函数区分了从配置文件加载(`load_from_config=True`)和直接从UI参数加载两种模式,并优先使用配置文件中的设置,确保配置的一致性。

🚀 **流式任务执行与事件处理**:`run_stream`函数是执行具体任务的核心驱动。它首先确保团队实例已初始化,然后接收任务指令,并通过`self.team.run_stream`启动任务的异步执行。在此过程中,函数会迭代接收来自团队的各种事件消息,包括标准的AI消息(`ChatMessage`)、LLM调用事件(`LLMCallEventMessage`)、以及任务执行的最终结果(`TaskResult`)。这些事件会以流式方式(`yield`)传递给调用者,使得用户可以实时感知任务的进展。

📁 **文件生成监控与整合**:在任务执行过程中,`run_stream`函数会持续监控工作目录,检测是否有新文件的生成。通过比较当前文件列表与已知文件集合的差异,可以准确识别出新创建的文件。这些新生成的文件(排除临时代码文件)会被封装成带有`type: 'file'`元数据的系统消息,并实时`yield`。最终,所有在任务执行期间生成的文件都会被收集并作为`TeamResult`的一部分一并输出,为用户提供完整的执行结果。

💡 **状态管理与资源调度**:`_create_team`函数支持加载先前保存的团队状态,通过`state`参数接收(可以是字符串或字典),并调用`self.team.load_state`方法来恢复团队的运行上下文。此外,函数还负责获取并传递浏览器相关的端口信息(`novnc_port`和`playwright_port`),这些信息对于用户访问浏览器界面或进行调试至关重要,并在`run_stream`函数中通过`TextMessage`的形式对外广播。

🧹 **任务取消与资源清理**:`run_stream`函数集成了对`cancellation_token`的支持,允许在任务执行过程中随时中断。当取消信号被触发时,函数会优雅地停止任务的进一步执行。无论任务是否成功完成或被取消,`finally`块中的代码都会确保释放所有占用的资源,特别是调用`self.team.close()`来关闭团队实例,防止资源泄漏,保证系统的稳定运行。

Magnetic-UI源码解析

Teammanager 的实现

create_team

    async def _create_team(        self,        team_config: Union[str, Path, Dict[str, Any], ComponentModel],        state: Optional[Mapping[str, Any] | str] = None,        input_func: Optional[InputFuncType] = None,        env_vars: Optional[List[EnvironmentVariable]] = None,        settings_config: dict[str, Any] = {},        *,        paths: RunPaths,    ) -> tuple[Team, int, int]:        """        根据配置创建团队实例        """        # 获取浏览器资源配置(novnc端口和playwright端口)        _, novnc_port, playwright_port = get_browser_resource_config(            paths.external_run_dir, -1, -1, self.inside_docker        )        try:            # 如果从配置文件加载            if self.load_from_config:                # 当从配置文件加载时,使用get_task_team而不是team_config                logger.info("Loading team from configuration file")                                # 解析settings_config中的模型配置(如果可用)                settings_model_configs: Dict[str, Any] = {}                if "model_configs" in settings_config:                    try:                        settings_model_configs = yaml.safe_load(                            settings_config["model_configs"]                        )                    except Exception as e:                        logger.warning(                            f"Error loading model configs from UI. Using defaults. Inner exception: {e}"                        )                # 使用settings_config值(如果可用),否则回退到实例默认值(self.config)                logger.info(f"Debug - self.config keys: {list(self.config.keys()) if self.config else 'None'}")                logger.info(f"Debug - settings_model_configs: {settings_model_configs}")                logger.info(f"Debug - load_from_config: {self.load_from_config}")                                # 重要:当load_from_config为True时,优先使用文件配置而非UI设置                logger.info("Prioritizing config file over UI settings")                # 优先使用文件配置,仅当文件配置为None时才回退到设置                orchestrator_config = (                    self.config.get("orchestrator_client") or                     settings_model_configs.get("orchestrator_client")                )                web_surfer_config = (                    self.config.get("web_surfer_client") or                     settings_model_configs.get("web_surfer_client")                )                coder_config = (                    self.config.get("coder_client") or                     settings_model_configs.get("coder_client")                )                file_surfer_config = (                    self.config.get("file_surfer_client") or                     settings_model_configs.get("file_surfer_client")                )                action_guard_config = (                    self.config.get("action_guard_client") or                     settings_model_configs.get("action_guard_client")                )                                logger.info(f"Debug - final orchestrator_config: {orchestrator_config}")                                # 创建模型客户端配置对象                model_client_configs = ModelClientConfigs(                    orchestrator=orchestrator_config,                    web_surfer=web_surfer_config,                    coder=coder_config,                    file_surfer=file_surfer_config,                    action_guard=action_guard_config,                )                # 创建MagenticUI配置对象                magentic_ui_config = MagenticUIConfig(                    **{                        # 最低优先级默认值                        **self.config,  # type: ignore                        # 提供的设置覆盖默认值                        **settings_config,  # type: ignore,                        # 设置为手动合并的字典                        "model_client_configs": model_client_configs,                        # 这些必须始终设置为上面计算的值                        "playwright_port": playwright_port,                        "novnc_port": novnc_port,                        # 回退到self获取inside_docker值                        "inside_docker": self.inside_docker,                    }                )                # 获取任务团队实例                self.team = cast(                    Team,                    await get_task_team(                        magentic_ui_config=magentic_ui_config,                        input_func=input_func,                        paths=paths,                    ),                )                # 如果团队有参与者,查找WebSurfer代理并获取端口信息                if hasattr(self.team, "_participants"):                    for agent in cast(list[ChatAgent], self.team._participants):  # type: ignore                        if isinstance(agent, WebSurfer):                            novnc_port = agent.novnc_port                            playwright_port = agent.playwright_port                # 如果有状态信息,则加载状态                if state:                    if isinstance(state, str):                        try:                            # 尝试解压缩状态(如果是压缩的)                            state_dict = decompress_state(state)                            await self.team.load_state(state_dict)                        except Exception:                            # 如果解压缩失败,假设它是常规JSON字符串                            state_dict = json.loads(state)                            await self.team.load_state(state_dict)                    else:                        await self.team.load_state(state)                return self.team, novnc_port, playwright_port                        # 原始team_config加载逻辑(非配置文件模式)            else:                logger.info("Loading team from team_config (UI mode)")                # settings_config字典在键`model_configs`中提供模型配置                # 但MagenticUIConfig期望`model_client_configs`,所以需要在此处更新                settings_model_configs: Dict[str, Any] = {}                if "model_configs" in settings_config:                    try:                        settings_model_configs = yaml.safe_load(                            settings_config["model_configs"]                        )                    except Exception as e:                        logger.warning(                            f"Error loading model configs from UI. Using defaults. Inner exception: {e}"                        )                # 使用UI设置作为优先级,回退到文件配置                orchestrator_config = settings_model_configs.get(                    "orchestrator_client",                    self.config.get("orchestrator_client", None),                )                web_surfer_config = settings_model_configs.get(                    "web_surfer_client",                    self.config.get("web_surfer_client", None),                )                coder_config = settings_model_configs.get(                    "coder_client", self.config.get("coder_client", None)                )                file_surfer_config = settings_model_configs.get(                    "file_surfer_client",                    self.config.get("file_surfer_client", None),                )                action_guard_config = settings_model_configs.get(                    "action_guard_client",                    self.config.get("action_guard_client", None),                )                                # 创建模型客户端配置对象                model_client_configs = ModelClientConfigs(                    orchestrator=orchestrator_config,                    web_surfer=web_surfer_config,                    coder=coder_config,                    file_surfer=file_surfer_config,                    action_guard=action_guard_config,                )                # 创建MagenticUI配置对象                magentic_ui_config = MagenticUIConfig(                    **{                        # 最低优先级默认值                        **self.config,  # type: ignore                        # 提供的设置覆盖默认值                        **settings_config,  # type: ignore,                        # 设置为手动合并的字典                        "model_client_configs": model_client_configs,                        # 这些必须始终设置为上面计算的值                        "playwright_port": playwright_port,                        "novnc_port": novnc_port,                        # 回退到self获取inside_docker值                        "inside_docker": self.inside_docker,                    }                )                # 获取任务团队实例                self.team = cast(                    Team,                    await get_task_team(                        magentic_ui_config=magentic_ui_config,                        input_func=input_func,                        paths=paths,                    ),                )                # 如果团队有参与者,查找WebSurfer代理并获取端口信息                if hasattr(self.team, "_participants"):                    for agent in cast(list[ChatAgent], self.team._participants):  # type: ignore                        if isinstance(agent, WebSurfer):                            novnc_port = agent.novnc_port                            playwright_port = agent.playwright_port                # 如果有状态信息,则加载状态                if state:                    if isinstance(state, str):                        try:                            # 尝试解压缩状态(如果是压缩的)                            state_dict = decompress_state(state)                            await self.team.load_state(state_dict)                        except Exception:                            # 如果解压缩失败,假设它是常规JSON字符串                            state_dict = json.loads(state)                            await self.team.load_state(state_dict)                    else:                        await self.team.load_state(state)                return self.team, novnc_port, playwright_port                    except Exception as e:            logger.error(f"Error creating team: {e}")            await self.close()            raise

该函数根据传入配置(配置文件或UI配置)动态创建一个“团队”(Team 实例),并返回该团队实例和与浏览器资源相关的端口信息(novnc_port, playwright_port)

支持两种方式加载配置文件:

run_stream

    async def run_stream(        self,        task: Optional[Union[ChatMessage, str, Sequence[ChatMessage]]],        team_config: Union[str, Path, dict[str, Any], ComponentModel],        state: Optional[Mapping[str, Any] | str] = None,        input_func: Optional[InputFuncType] = None,        cancellation_token: Optional[CancellationToken] = None,        env_vars: Optional[List[EnvironmentVariable]] = None,        settings_config: Optional[Dict[str, Any]] = None,        run: Optional[Run] = None,    ) -> AsyncGenerator[        Union[AgentEvent, ChatMessage, LLMCallEventMessage, TeamResult], None    ]:        """        流式返回团队执行结果        """        # 记录开始时间        start_time = time.time()        # 正确设置日志记录器        logger = logging.getLogger(EVENT_LOGGER_NAME)        logger.setLevel(logging.CRITICAL)        # 创建运行事件日志记录器        llm_event_logger = RunEventLogger()        logger.handlers = [llm_event_logger]  # 替换所有处理程序        logger.info(f"Running in docker: {self.inside_docker}")        # 准备运行路径        paths = self.prepare_run_paths(run=run)        # 获取已知文件集合,用于跟踪新生成的文件        known_files = set(            file["name"]            for file in get_modified_files(                0, time.time(), source_dir=str(paths.internal_run_dir)            )        )        # 存储全局新文件列表        global_new_files: List[Dict[str, str]] = []        try:            # TODO: 如果我们不谨慎,这可能会导致问题            # 如果团队未初始化,则创建团队            if self.team is None:                # TODO: 如果我们开始允许从配置加载,我们需要将novnc和playwright端口写回到团队配置中                _, _novnc_port, _playwright_port = await self._create_team(                    team_config,                    state,                    input_func,                    env_vars,                    settings_config or {},                    paths=paths,                )                # 通过名称初始化已知文件以进行跟踪                initial_files = get_modified_files(                    start_time, time.time(), source_dir=str(paths.internal_run_dir)                )                known_files = {file["name"] for file in initial_files}                # 生成浏览器地址信息消息                yield TextMessage(                    source="system",                    content=f"Browser noVNC address can be found at http://localhost:{_novnc_port}/vnc.html",                    metadata={                        "internal": "no",                        "type": "browser_address",                        "novnc_port": str(_novnc_port),                        "playwright_port": str(_playwright_port),                    },                )                # 异步迭代团队运行流中的消息                async for message in self.team.run_stream(  # type: ignore                    task=task, cancellation_token=cancellation_token                ):                    # 检查是否需要取消操作                    if cancellation_token and cancellation_token.is_cancelled():                        break                    # 获取所有当前文件及其完整元数据                    modified_files = get_modified_files(                        start_time, time.time(), source_dir=str(paths.internal_run_dir)                    )                    current_file_names = {file["name"] for file in modified_files}                    # 查找新文件                    new_file_names = current_file_names - known_files                    known_files = current_file_names  # 为下一次迭代更新                    # 获取新文件的完整数据                    new_files = [                        file                        for file in modified_files                        if file["name"] in new_file_names                    ]                    # 如果有新文件生成,则发送文件消息                    if new_files:                        # 过滤以"tmp_code"开头的文件                        new_files = [                            file                            for file in new_files                            if not file["name"].startswith("tmp_code")                        ]                        if len(new_files) > 0:                            file_message = TextMessage(                                source="system",                                content="File Generated",                                metadata={                                    "internal": "no",                                    "type": "file",                                    "files": json.dumps(new_files),                                },                            )                            global_new_files.extend(new_files)                            yield file_message                    # 处理任务结果消息                    if isinstance(message, TaskResult):                        yield TeamResult(                            task_result=message,                            usage="",                            duration=time.time() - start_time,                            files=modified_files,  # 保留完整文件数据                        )                    else:                        yield message                    # 添加生成的文件到最终输出                    if (                        isinstance(message, TextMessage)                        and message.metadata.get("type", "") == "final_answer"                    ):                        # 如果有全局新文件,则发送文件消息                        if len(global_new_files) > 0:                            # 只保留唯一的文件名,如果有相同名称的文件,保留最新的一个                            global_new_files = list(                                {                                    file["name"]: file for file in global_new_files                                }.values()                            )                            file_message = TextMessage(                                source="system",                                content="File Generated",                                metadata={                                    "internal": "no",                                    "type": "file",                                    "files": json.dumps(global_new_files),                                },                            )                            yield file_message                            global_new_files = []                    # 检查是否有LLM事件                    while not llm_event_logger.events.empty():                        event = await llm_event_logger.events.get()                        yield event        finally:            # 清理 - 移除我们的处理程序            if llm_event_logger in logger.handlers:                logger.handlers.remove(llm_event_logger)            # 确保清理发生            if self.team and hasattr(self.team, "close"):                logger.info("Closing team")                await self.team.close()  # type: ignore                logger.info("Team closed")

该函数主要功能:

run_stream 函数核心功能:

    初始化团队(如未初始化)

    启动任务的异步运行

    监听运行过程中的输出:

      普通 AI 消息(ChatMessage)最终结果(TaskResult → TeamResult)文件生成事件(系统消息)LLM 事件(如 Token 使用等)

    支持任务取消

    运行结束后进行资源清理

传递浏览器消息

       # 生成浏览器地址信息消息                yield TextMessage(                    source="system",                    content=f"Browser noVNC address can be found at http://localhost:{_novnc_port}/vnc.html",                    metadata={                        "internal": "no",                        "type": "browser_address",                        "novnc_port": str(_novnc_port),                        "playwright_port": str(_playwright_port),                    },                )

启动异步任务执行

async for message in self.team.run_stream(...):

从团队 run_stream 接收 逐条消息输出,包括中间内容、LLM回复、系统提示、最终结果。

文件监控

modified_files = get_modified_files(...)new_file_names = current_file_names - known_files # 集合求差

get_modified_files使用 os.path.getmtime() 获取文件的最后修改时间,然后与给定的时间范围进行比较,只保留时间范围内的文件。

每次团队输出后,检查目录中是否有新增文件。如果检测到新文件,就封装为 TextMessage:

yield TextMessage(content="File Generated", metadata={...})

并添加到 global_new_files 用于最终输出。

输出最终结果(如有)

if isinstance(message, TaskResult):    yield TeamResult(...)

最终文件再次确认输出

if message.metadata.get("type", "") == "final_answer":    # 发送全局文件信息

LLM事件日志输出

while not llm_event_logger.events.empty():    event = await llm_event_logger.events.get()    yield event

WebSocketManager实现

WebSocketManager类主要负责管理WebSocket连接和团队任务执行的消息流。

import asyncioimport jsonimport loggingimport tracebackfrom datetime import datetime, timezonefrom pathlib import Pathfrom typing import Any, Dict, Optional, Sequence, Unionfrom autogen_agentchat.base._task import TaskResultfrom autogen_agentchat.messages import (AgentEvent, ChatMessage,                                        HandoffMessage,                                        ModelClientStreamingChunkEvent,                                        MultiModalMessage, StopMessage,                                        TextMessage, ToolCallExecutionEvent,                                        ToolCallRequestEvent)from autogen_core import CancellationTokenfrom fastapi import WebSocket, WebSocketDisconnectfrom ....input_func import InputFuncType, InputRequestTypefrom ....types import CheckpointEventfrom ...database import DatabaseManagerfrom ...datamodel import (LLMCallEventMessage, Message, MessageConfig, Run,                          RunStatus, Settings, SettingsConfig, TeamResult)from ...teammanager import TeamManagerfrom ...utils.utils import compress_statelogger = logging.getLogger(__name__)class WebSocketManager:    """    管理WebSocket连接和团队任务执行的消息流        参数:        db_manager (DatabaseManager): 用于数据库操作的数据库管理器实例        internal_workspace_root (Path): 内部根目录的路径        external_workspace_root (Path): 外部根目录的路径        inside_docker (bool): 指示应用程序是否在Docker内运行的标志        config (dict): Magentic-UI的配置    """    # Manages WebSocket connections and message streaming for team task execution    #     # Args:    #     db_manager (DatabaseManager): Database manager instance for database operations    #     internal_workspace_root (Path): Path to the internal root directory    #     external_workspace_root (Path): Path to the external root directory    #     inside_docker (bool): Flag indicating if the application is running inside Docker    #     config (dict): Configuration for Magentic-UI    def __init__(        self,        db_manager: DatabaseManager,        internal_workspace_root: Path,        external_workspace_root: Path,        inside_docker: bool,        config: Dict[str, Any],    ):        # 数据库管理器实例        self.db_manager = db_manager        # 内部工作区根目录路径        self.internal_workspace_root = internal_workspace_root        # 外部工作区根目录路径        self.external_workspace_root = external_workspace_root        # 是否在Docker内部运行的标志        self.inside_docker = inside_docker        # Magentic-UI配置        self.config = config        # WebSocket连接字典,run_id映射到WebSocket对象        self._connections: Dict[int, WebSocket] = {}        # 取消令牌字典,run_id映射到CancellationToken对象        self._cancellation_tokens: Dict[int, CancellationToken] = {}        # 跟踪显式关闭的连接        # Track explicitly closed connections        self._closed_connections: set[int] = set()        # 输入响应队列字典,run_id映射到asyncio.Queue对象        self._input_responses: Dict[int, asyncio.Queue[str]] = {}        # 团队管理器字典,run_id映射到TeamManager对象        self._team_managers: Dict[int, TeamManager] = {}        # 取消消息模板        self._cancel_message = TeamResult(            task_result=TaskResult(                messages=[TextMessage(source="user", content="Run cancelled by user")],                stop_reason="cancelled by user",            ),            usage="",            duration=0,        ).model_dump()    def _get_stop_message(self, reason: str) -> dict[str, Any]:        # 根据给定原因生成停止消息        return TeamResult(            task_result=TaskResult(                messages=[TextMessage(source="user", content=reason)],                stop_reason=reason,            ),            usage="",            duration=0,        ).model_dump()    async def connect(self, websocket: WebSocket, run_id: int) -> bool:        # 建立WebSocket连接        try:            await websocket.accept()            self._connections[run_id] = websocket            self._closed_connections.discard(run_id)            # 为此连接初始化输入队列            # Initialize input queue for this connection            self._input_responses[run_id] = asyncio.Queue()            await self._send_message(                run_id,                {                    "type": "system",                    "status": "connected",                    "timestamp": datetime.now(timezone.utc).isoformat(),                },            )            return True        except Exception as e:            logger.error(f"Connection error for run {run_id}: {e}")            return False    async def start_stream(        self,        run_id: int,        task: str | ChatMessage | Sequence[ChatMessage] | None,        team_config: Dict[str, Any],        settings_config: Dict[str, Any],        user_settings: Settings | None = None,    ) -> None:        """        开始流式任务执行并进行适当的运行管理                参数:            run_id (int): 运行的ID            task (str | ChatMessage | Sequence[ChatMessage] | None): 要执行的任务            team_config (Dict[str, Any]): 团队的配置            settings_config (Dict[str, Any]): 设置的配置            user_settings (Settings, optional): 运行的用户设置        """        # Start streaming task execution with proper run management        #         # Args:        #     run_id (int): ID of the run        #     task (str | ChatMessage | Sequence[ChatMessage] | None): Task to execute        #     team_config (Dict[str, Any]): Configuration for the team        #     settings_config (Dict[str, Any]): Configuration for settings        #     user_settings (Settings, optional): User settings for the run        if run_id not in self._connections or run_id in self._closed_connections:            raise ValueError(f"No active connection for run {run_id}")        # 如果已存在团队管理器,则不创建新的        # do not create a new team manager if one already exists        if run_id not in self._team_managers:            team_manager = TeamManager(                internal_workspace_root=self.internal_workspace_root,                external_workspace_root=self.external_workspace_root,                inside_docker=self.inside_docker,                config=self.config,            )            self._team_managers[run_id] = team_manager        else:            team_manager = self._team_managers[run_id]        cancellation_token = CancellationToken()        self._cancellation_tokens[run_id] = cancellation_token        final_result = None        try:            # 使用任务和状态更新运行            # Update run with task and status            run = await self._get_run(run_id)            assert run is not None, f"Run {run_id} not found in database"            assert run.user_id is not None, f"Run {run_id} has no user ID"            # 获取用户设置            # Get user Settings            user_settings = await self._get_settings(run.user_id)            env_vars = (                SettingsConfig(**user_settings.config).environment  # type: ignore                if user_settings                else None            )            settings_config["memory_controller_key"] = run.user_id            state = None            if run:                run.task = MessageConfig(content=task, source="user").model_dump()                run.status = RunStatus.ACTIVE                state = run.state                self.db_manager.upsert(run)                await self._update_run_status(run_id, RunStatus.ACTIVE)            # 将任务添加为消息            # add task as message            if isinstance(task, str):                await self._send_message(                    run_id,                    self._format_message(TextMessage(source="user_proxy", content=task))                    or {},                )                await self._save_message(                    run_id, TextMessage(source="user_proxy", content=task)                )            elif isinstance(task, Sequence):                for task_message in task:                    if isinstance(task_message, TextMessage) or isinstance(                        task_message, MultiModalMessage                    ):                        if (                            hasattr(task_message, "metadata")                            and task_message.metadata.get("internal") == "yes"                        ):                            continue                        await self._send_message(                            run_id, self._format_message(task_message) or {}                        )                        await self._save_message(run_id, task_message)            input_func: InputFuncType = self.create_input_func(run_id)            message: ChatMessage | AgentEvent | TeamResult | LLMCallEventMessage            async for message in team_manager.run_stream(                task=task,                team_config=team_config,                state=state,                input_func=input_func,                cancellation_token=cancellation_token,                env_vars=env_vars,                settings_config=settings_config,                run=run,            ):                if (                    cancellation_token.is_cancelled()                    or run_id in self._closed_connections                ):                    logger.info(                        f"Stream cancelled or connection closed for run {run_id}"                    )                    break                if isinstance(message, CheckpointEvent):                    # 保存状态到运行中                    # Save state to run                    run = await self._get_run(run_id)                    if run:                        # 使用compress_state工具压缩状态                        # Use compress_state utility to compress the state                        state_dict = json.loads(message.state)                        run.state = compress_state(state_dict)                        self.db_manager.upsert(run)                    continue                # 不显示内部消息                # do not show internal messages                if (                    hasattr(message, "metadata")                    and message.metadata.get("internal") == "yes"  # type: ignore                ):                    continue                formatted_message = self._format_message(message)                if formatted_message:                    await self._send_message(run_id, formatted_message)                    # 按具体类型保存消息                    # Save messages by concrete type                    if isinstance(                        message,                        (                            TextMessage,                            MultiModalMessage,                            StopMessage,                            HandoffMessage,                            ToolCallRequestEvent,                            ToolCallExecutionEvent,                            LLMCallEventMessage,                        ),                    ):                        await self._save_message(run_id, message)                    # 如果是TeamResult则捕获最终结果                    # Capture final result if it's a TeamResult                    elif isinstance(message, TeamResult):                        final_result = message.model_dump()                    self._team_managers[run_id] = team_manager  # 跟踪团队管理器            if (                not cancellation_token.is_cancelled()                and run_id not in self._closed_connections            ):                if final_result:                    await self._update_run(                        run_id, RunStatus.COMPLETE, team_result=final_result                    )                else:                    logger.warning(                        f"No final result captured for completed run {run_id}"                    )                    await self._update_run_status(run_id, RunStatus.COMPLETE)            else:                await self._send_message(                    run_id,                    {                        "type": "completion",                        "status": "cancelled",                        "data": self._cancel_message,                        "timestamp": datetime.now(timezone.utc).isoformat(),                    },                )                # 使用取消结果更新运行                # Update run with cancellation result                await self._update_run(                    run_id, RunStatus.STOPPED, team_result=self._cancel_message                )        except Exception as e:            logger.error(f"Stream error for run {run_id}: {e}")            traceback.print_exc()            await self._handle_stream_error(run_id, e)        finally:            self._cancellation_tokens.pop(run_id, None)            self._team_managers.pop(run_id, None)  # 完成后移除团队管理器    async def _save_message(        self, run_id: int, message: Union[AgentEvent | ChatMessage, LLMCallEventMessage]    ) -> None:        """        将消息保存到数据库                参数:            run_id (int): 运行的ID            message (Union[AgentEvent | ChatMessage, LLMCallEventMessage]): 要保存的消息        """        # Save a message to the database        #         # Args:        #     run_id (int): ID of the run        #     message (Union[AgentEvent | ChatMessage, LLMCallEventMessage]): Message to save        run = await self._get_run(run_id)        if run:            db_message = Message(                created_at=datetime.now(),                session_id=run.session_id,                run_id=run_id,                config=message.model_dump(),                user_id=run.user_id,  # 从运行对象传递user_id                # Pass the user_id from the run object            )            self.db_manager.upsert(db_message)    async def _update_run(        self,        run_id: int,        status: RunStatus,        team_result: Optional[TeamResult | Dict[str, Any]] = None,        error: Optional[str] = None,    ) -> None:        """        更新运行状态和结果                参数:            run_id (int): 运行的ID            status (RunStatus): 要设置的新状态            team_result (TeamResult | dict[str, Any], optional): 可选的团队结果设置            error (str, optional): 可选的错误消息        """        # Update run status and result        #         # Args:        #     run_id (int): ID of the run        #     status (RunStatus): New status to set        #     team_result (TeamResult | dict[str, Any], optional): Optional team result to set        #     error (str, optional): Optional error message        run = await self._get_run(run_id)        if run:            run.status = status            if team_result:                run.team_result = team_result            if error:                run.error_message = error            self.db_manager.upsert(run)    def create_input_func(self, run_id: int, timeout: int = 600) -> InputFuncType:        """        为特定运行创建输入函数                参数:            run_id (int): 运行的ID            timeout (int, optional): 输入响应的超时时间(秒)。默认值:600        返回:            InputFuncType: 运行的输入函数        """        # Creates an input function for a specific run        #         # Args:        #     run_id (int): ID of the run        #     timeout (int, optional): Timeout for input response in seconds. Default: 600        # Returns:        #     InputFuncType: Input function for the run        async def input_handler(            prompt: str = "",            cancellation_token: Optional[CancellationToken] = None,            input_type: InputRequestType = "text_input",        ) -> str:            try:                # 如果运行已暂停则恢复运行                # resume run if it is paused                await self.resume_run(run_id)                # 将运行状态更新为等待输入                # update run status to awaiting_input                await self._update_run_status(run_id, RunStatus.AWAITING_INPUT)                # 向客户端发送输入请求                # Send input request to client                logger.info(                    f"Sending input request for run {run_id}: ({input_type}) {prompt}"                )                await self._send_message(                    run_id,                    {                        "type": "input_request",                        "input_type": input_type,                        "prompt": prompt,                        "data": {"source": "system", "content": prompt},                        "timestamp": datetime.now(timezone.utc).isoformat(),                    },                )                # 在Run对象中存储input_request                # Store input_request in the Run object                run = await self._get_run(run_id)                if run:                    run.input_request = {"prompt": prompt, "input_type": input_type}                    self.db_manager.upsert(run)                # 等待带有超时的响应                # Wait for response with timeout                if run_id in self._input_responses:                    try:                        async def poll_for_response():                            while True:                                # 检查运行是否已关闭/取消                                # Check if run was closed/cancelled                                if run_id in self._closed_connections:                                    raise ValueError("Run was closed")                                                                # 检查取消令牌是否已设置并已取消                                # Check if cancellation token is set and cancelled                                if cancellation_token and hasattr(cancellation_token, 'is_cancelled'):                                    try:                                        if cancellation_token.is_cancelled():                                            raise ValueError("Run was cancelled")                                    except Exception:                                        # 如果检查取消状态失败,则假定已取消                                        # If checking cancellation status fails, assume it's cancelled                                        raise ValueError("Run was cancelled")                                # 尝试使用短超时获取响应                                # Try to get response with short timeout                                try:                                    response = await asyncio.wait_for(                                        self._input_responses[run_id].get(),                                        timeout=min(timeout, 5),                                    )                                    await self._update_run_status(                                        run_id, RunStatus.ACTIVE                                    )                                    return response                                except asyncio.TimeoutError:                                    continue  # 继续检查关闭状态                        response = await asyncio.wait_for(                            poll_for_response(), timeout=timeout                        )                        return response                    except asyncio.TimeoutError:                        # 如果发生超时则停止运行                        # Stop the run if timeout occurs                        logger.warning(f"Input response timeout for run {run_id}")                        await self.stop_run(                            run_id,                            "Magentic-UI timed out while waiting for your input. To resume, please enter a follow-up message in the input box or you can simply type 'continue'.",                        )                        raise                else:                    raise ValueError(f"No input queue for run {run_id}")            except Exception as e:                logger.error(f"Error handling input for run {run_id}: {e}")                raise        return input_handler    async def handle_input_response(self, run_id: int, response: str) -> None:        # 处理来自客户端的输入响应        # Handle input response from client        if run_id in self._input_responses:            await self._input_responses[run_id].put(response)        else:            logger.warning(f"Received input response for inactive run {run_id}")    async def stop_run(self, run_id: int, reason: str) -> None:        # 停止运行        if run_id in self._cancellation_tokens:            logger.info(f"Stopping run {run_id}")            stop_message = self._get_stop_message(reason)            try:                # 首先更新运行记录                # Update run record first                await self._update_run(                    run_id, status=RunStatus.STOPPED, team_result=stop_message                )                # 如果连接处于活动状态,则处理websocket通信                # Then handle websocket communication if connection is active                if (                    run_id in self._connections                    and run_id not in self._closed_connections                ):                    await self._send_message(                        run_id,                        {                            "type": "completion",                            "status": "cancelled",                            "data": stop_message,                            "timestamp": datetime.now(timezone.utc).isoformat(),                        },                    )                # 最后取消令牌                # Finally cancel the token                cancellation_token = self._cancellation_tokens.get(run_id)                if cancellation_token:                    try:                        # 检查令牌是否有取消方法                        # Check if the token has a cancel method                        if hasattr(cancellation_token, 'cancel') and callable(cancellation_token.cancel):                            import inspect                            # 检查取消方法是否为异步                            # Check if cancel method is async                            if inspect.iscoroutinefunction(cancellation_token.cancel):                                await cancellation_token.cancel()                            else:                                cancellation_token.cancel()                        else:                            logger.warning(f"Cancellation token for run {run_id} does not have a cancel method")                    except Exception as cancel_error:                        logger.warning(f"Error cancelling token for run {run_id}: {cancel_error}")                                # 移除团队管理器                # remove team manager                team_manager = self._team_managers.pop(run_id, None)                if team_manager:                    try:                        await team_manager.close()                    except Exception as close_error:                        logger.warning(f"Error closing team manager for run {run_id}: {close_error}")            except Exception as e:                logger.error(f"Error stopping run {run_id}: {e}")                # 只有当我们不在断开连接过程中时才强制断开连接                # Only force disconnect if we're not already in disconnection process                if run_id not in self._closed_connections:                    try:                        await self.disconnect(run_id)                    except Exception as disconnect_error:                        logger.error(f"Error during forced disconnect for run {run_id}: {disconnect_error}")    async def disconnect(self, run_id: int) -> None:        """        清理连接和相关资源                参数:            run_id (int): 要断开连接的运行ID        """        # Clean up connection and associated resources        #         # Args:        #     run_id (int): ID of the run to disconnect        logger.info(f"Disconnecting run {run_id}")        # 在清理前标记为已关闭以防止任何新消息        # Mark as closed before cleanup to prevent any new messages        self._closed_connections.add(run_id)        # 只有当我们尚未开始断开连接时才停止运行        # to avoid infinite recursion between stop_run and disconnect        # Only stop run if we haven't already started disconnecting        # to avoid infinite recursion between stop_run and disconnect        if run_id in self._cancellation_tokens:            cancellation_token = self._cancellation_tokens.get(run_id)            if cancellation_token:                try:                    # 检查令牌是否有取消方法                    # Check if the token has a cancel method                    if hasattr(cancellation_token, 'cancel') and callable(cancellation_token.cancel):                        import inspect                        # 检查取消方法是否为异步                        # Check if cancel method is async                        if inspect.iscoroutinefunction(cancellation_token.cancel):                            await cancellation_token.cancel()                        else:                            cancellation_token.cancel()                except Exception as cancel_error:                    logger.warning(f"Error cancelling token during disconnect for run {run_id}: {cancel_error}")        # 关闭团队管理器        # Close team manager        team_manager = self._team_managers.pop(run_id, None)        if team_manager:            try:                await team_manager.close()            except Exception as close_error:                logger.warning(f"Error closing team manager during disconnect for run {run_id}: {close_error}")        # 清理资源        # Clean up resources        self._connections.pop(run_id, None)        self._cancellation_tokens.pop(run_id, None)        self._input_responses.pop(run_id, None)    async def _send_message(self, run_id: int, message: Dict[str, Any]) -> None:        """通过WebSocket发送消息并检查连接状态                参数:            run_id (int): 运行的ID            message (Dict[str, Any]): 要发送的消息字典        """        # Send a message through the WebSocket with connection state checking        #         # Args:        #     run_id (int): int of the run        #     message (Dict[str, Any]): Message dictionary to send        if run_id in self._closed_connections:            logger.warning(                f"Attempted to send message to closed connection for run {run_id}"            )            return        try:            if run_id in self._connections:                websocket = self._connections[run_id]                await websocket.send_json(message)        except WebSocketDisconnect:            logger.warning(                f"WebSocket disconnected while sending message for run {run_id}"            )            await self.disconnect(run_id)        except Exception as e:            logger.error(f"Error sending message for run {run_id}: {e}, {message}")            # 不要在此处尝试发送错误消息以避免潜在的递归循环            # Don't try to send error message here to avoid potential recursive loop            await self._update_run_status(run_id, RunStatus.ERROR, str(e))            await self.disconnect(run_id)    async def _handle_stream_error(self, run_id: int, error: Exception) -> None:        """        处理流错误并正确更新运行                参数:            run_id (int): 运行的ID            error (Exception): 发生的异常        """        # Handle stream errors with proper run updates        #         # Args:        #     run_id (int): ID of the run        #     error (Exception): Exception that occurred        if run_id not in self._closed_connections:            error_result = TeamResult(                task_result=TaskResult(                    messages=[TextMessage(source="system", content=str(error))],                    stop_reason="An error occurred while processing this run",                ),                usage="",                duration=0,            ).model_dump()            await self._send_message(                run_id,                {                    "type": "completion",                    "status": "error",                    "data": error_result,                    "timestamp": datetime.now(timezone.utc).isoformat(),                },            )            await self._update_run(                run_id, RunStatus.ERROR, team_result=error_result, error=str(error)            )    def _format_message(self, message: Any) -> Optional[Dict[str, Any]]:        """格式化用于WebSocket传输的消息                参数:            message (Any): 要格式化的消息                    返回:            Optional[Dict[str, Any]]: 格式化的消息,如果格式化失败则返回None        """        # Format message for WebSocket transmission        #         # Args:        #     message (Any): Message to format        #         # Returns:        #     Optional[Dict[str, Any]]: Formatted message or None if formatting fails        try:            if isinstance(message, MultiModalMessage):                message_dump = message.model_dump()                message_content: list[dict[str, Any]] = []                for row in message_dump["content"]:                    if "data" in row:                        message_content.append(                            {                                "url": f"data:image/png;base64,{row['data']}",                                "alt": "WebSurfer Screenshot",                            }                        )                    else:                        message_content.append(row)                message_dump["content"] = message_content                return {"type": "message", "data": message_dump}            elif isinstance(message, TeamResult):                return {                    "type": "result",                    "data": message.model_dump(),                    "status": "complete",                }            elif isinstance(message, ModelClientStreamingChunkEvent):                return {"type": "message_chunk", "data": message.model_dump()}            elif isinstance(                message,                (TextMessage,),            ):                return {"type": "message", "data": message.model_dump()}            elif isinstance(message, str):                return {                    "type": "message",                    "data": {"source": "user", "content": message},                }            return None        except Exception as e:            logger.error(f"Message formatting error: {e}")            return None    async def _get_run(self, run_id: int) -> Optional[Run]:        """从数据库获取运行                参数:            run_id (int): 要检索的运行ID                    返回:            Optional[Run]: 如果找到则返回Run对象,否则返回None        """        # Get run from database        #         # Args:        #     run_id (int): int of the run to retrieve        #         # Returns:        #     Optional[Run]: Run object if found, None otherwise        response = self.db_manager.get(Run, filters={"id": run_id}, return_json=False)        return response.data[0] if response.status and response.data else None    async def _get_settings(self, user_id: str) -> Optional[Settings]:        """从数据库获取用户设置                参数:            user_id (str): 要检索设置的用户ID                    返回:            Optional[Settings]: 如果找到则返回用户设置,否则返回None        """        # Get user settings from database        # Args:        #     user_id (str): User ID to retrieve settings for        # Returns:        #     Optional[Settings]: User settings if found, None otherwise        response = self.db_manager.get(            filters={"user_id": user_id}, model_class=Settings, return_json=False        )        return response.data[0] if response.status and response.data else None    async def _update_run_status(        self, run_id: int, status: RunStatus, error: Optional[str] = None    ) -> None:        """更新数据库中的运行状态                参数:            run_id (int): 要更新的运行ID            status (RunStatus): 要设置的新状态            error (str, optional): 可选的错误消息        """        # Update run status in database        #         # Args:        #     run_id (int): int of the run to update        #     status (RunStatus): New status to set        #     error (str, optional): Optional error message        run = await self._get_run(run_id)        if run:            run.status = status            run.error_message = error            self.db_manager.upsert(run)        # 向客户端发送带有状态的系统消息        # send system message to client with status        await self._send_message(            run_id,            {                "type": "system",                "status": status,                "timestamp": datetime.now(timezone.utc).isoformat(),            },        )    async def cleanup(self) -> None:        """服务器关闭时清理所有活动连接和资源"""        # Clean up all active connections and resources when server is shutting down        logger.info(f"Cleaning up {len(self.active_connections)} active connections")        try:            # 首先取消所有正在运行的任务            # First cancel all running tasks            for run_id in self.active_runs.copy():                if run_id in self._cancellation_tokens:                    self._cancellation_tokens[run_id].cancel()                run = await self._get_run(run_id)                if run and run.status == RunStatus.ACTIVE:                    interrupted_result = TeamResult(                        task_result=TaskResult(                            messages=[                                TextMessage(                                    source="system",                                    content="Run interrupted by server shutdown",                                )                            ],                            stop_reason="server_shutdown",                        ),                        usage="",                        duration=0,                    ).model_dump()                    run.status = RunStatus.STOPPED                    run.team_result = interrupted_result                    self.db_manager.upsert(run)            # 然后在超时时间内断开所有websocket连接            # 10秒超时用于整个清理过程            # Then disconnect all websockets with timeout            # 10 second timeout for entire cleanup            async def disconnect_all():                for run_id in self.active_connections.copy():                    try:                        await asyncio.wait_for(self.disconnect(run_id), timeout=2)                    except asyncio.TimeoutError:                        logger.warning(f"Timeout disconnecting run {run_id}")                    except Exception as e:                        logger.error(f"Error disconnecting run {run_id}: {e}")            await asyncio.wait_for(disconnect_all(), timeout=10)        except asyncio.TimeoutError:            logger.warning("WebSocketManager cleanup timed out")        except Exception as e:            logger.error(f"Error during WebSocketManager cleanup: {e}")        finally:            # 始终清除内部状态,即使清理出现错误            # Always clear internal state, even if cleanup had errors            self._connections.clear()            self._cancellation_tokens.clear()            self._closed_connections.clear()            self._input_responses.clear()    @property    def active_connections(self) -> set[int]:        """获取活动运行ID集合"""        # Get set of active run IDs        return set(self._connections.keys()) - self._closed_connections    @property    def active_runs(self) -> set[int]:        """获取具有活动取消令牌的运行集合"""        # Get set of runs with active cancellation tokens        return set(self._cancellation_tokens.keys())    async def pause_run(self, run_id: int) -> None:        """暂停运行"""        # Pause the run        if (            run_id in self._connections            and run_id not in self._closed_connections            and run_id in self._team_managers        ):            team_manager = self._team_managers.get(run_id)            if team_manager:                await team_manager.pause_run()                # await self._send_message(                #     run_id,                #     {                #         "type": "system",                #         "status": "paused",                #         "timestamp": datetime.now(timezone.utc).isoformat(),                #     },                # )                # await self._update_run_status(run_id, RunStatus.PAUSED)    async def resume_run(self, run_id: int) -> None:        """恢复运行"""        # Resume the run        if (            run_id in self._connections            and run_id not in self._closed_connections            and run_id in self._team_managers        ):            team_manager = self._team_managers.get(run_id)            if team_manager:                await team_manager.resume_run()                await self._update_run_status(run_id, RunStatus.ACTIVE)

connect

    async def connect(self, websocket: WebSocket, run_id: int) -> bool:        # 建立WebSocket连接        try:            await websocket.accept()            self._connections[run_id] = websocket            self._closed_connections.discard(run_id)            # 为此连接初始化输入队列            # Initialize input queue for this connection            self._input_responses[run_id] = asyncio.Queue()            await self._send_message(                run_id,                {                    "type": "system",                    "status": "connected",                    "timestamp": datetime.now(timezone.utc).isoformat(),                },            )            return True        except Exception as e:            logger.error(f"Connection error for run {run_id}: {e}")            return False

处理单个 WebSocket 客户端连接的注册与初始化:记录连接、创建输入队列、发送欢迎消息。

start_stream

    async def start_stream(        self,        run_id: int,        task: str | ChatMessage | Sequence[ChatMessage] | None,        team_config: Dict[str, Any],        settings_config: Dict[str, Any],        user_settings: Settings | None = None,    ) -> None:        """        开始流式任务执行并进行适当的运行管理                参数:            run_id (int): 运行的ID            task (str | ChatMessage | Sequence[ChatMessage] | None): 要执行的任务            team_config (Dict[str, Any]): 团队的配置            settings_config (Dict[str, Any]): 设置的配置            user_settings (Settings, optional): 运行的用户设置        """        # Start streaming task execution with proper run management        #         # Args:        #     run_id (int): ID of the run        #     task (str | ChatMessage | Sequence[ChatMessage] | None): Task to execute        #     team_config (Dict[str, Any]): Configuration for the team        #     settings_config (Dict[str, Any]): Configuration for settings        #     user_settings (Settings, optional): User settings for the run        if run_id not in self._connections or run_id in self._closed_connections:            raise ValueError(f"No active connection for run {run_id}")        # 如果已存在团队管理器,则不创建新的        # do not create a new team manager if one already exists        if run_id not in self._team_managers:            team_manager = TeamManager(                internal_workspace_root=self.internal_workspace_root,                external_workspace_root=self.external_workspace_root,                inside_docker=self.inside_docker,                config=self.config,            )            self._team_managers[run_id] = team_manager        else:            team_manager = self._team_managers[run_id]        cancellation_token = CancellationToken()        self._cancellation_tokens[run_id] = cancellation_token        final_result = None        try:            # 使用任务和状态更新运行            # Update run with task and status            run = await self._get_run(run_id)            assert run is not None, f"Run {run_id} not found in database"            assert run.user_id is not None, f"Run {run_id} has no user ID"            # 获取用户设置            # Get user Settings            user_settings = await self._get_settings(run.user_id)            env_vars = (                SettingsConfig(**user_settings.config).environment  # type: ignore                if user_settings                else None            )            settings_config["memory_controller_key"] = run.user_id            state = None            if run:                run.task = MessageConfig(content=task, source="user").model_dump()                run.status = RunStatus.ACTIVE                state = run.state                self.db_manager.upsert(run)                await self._update_run_status(run_id, RunStatus.ACTIVE)            # 将任务添加为消息            # add task as message            if isinstance(task, str):                await self._send_message(                    run_id,                    self._format_message(TextMessage(source="user_proxy", content=task))                    or {},                )                await self._save_message(                    run_id, TextMessage(source="user_proxy", content=task)                )            elif isinstance(task, Sequence):                for task_message in task:                    if isinstance(task_message, TextMessage) or isinstance(                        task_message, MultiModalMessage                    ):                        if (                            hasattr(task_message, "metadata")                            and task_message.metadata.get("internal") == "yes"                        ):                            continue                        await self._send_message(                            run_id, self._format_message(task_message) or {}                        )                        await self._save_message(run_id, task_message)            input_func: InputFuncType = self.create_input_func(run_id)            message: ChatMessage | AgentEvent | TeamResult | LLMCallEventMessage            async for message in team_manager.run_stream(                task=task,                team_config=team_config,                state=state,                input_func=input_func,                cancellation_token=cancellation_token,                env_vars=env_vars,                settings_config=settings_config,                run=run,            ):                if (                    cancellation_token.is_cancelled()                    or run_id in self._closed_connections                ):                    logger.info(                        f"Stream cancelled or connection closed for run {run_id}"                    )                    break                if isinstance(message, CheckpointEvent):                    # 保存状态到运行中                    # Save state to run                    run = await self._get_run(run_id)                    if run:                        # 使用compress_state工具压缩状态                        # Use compress_state utility to compress the state                        state_dict = json.loads(message.state)                        run.state = compress_state(state_dict)                        self.db_manager.upsert(run)                    continue                # 不显示内部消息                # do not show internal messages                if (                    hasattr(message, "metadata")                    and message.metadata.get("internal") == "yes"  # type: ignore                ):                    continue                formatted_message = self._format_message(message)                if formatted_message:                    await self._send_message(run_id, formatted_message)                    # 按具体类型保存消息                    # Save messages by concrete type                    if isinstance(                        message,                        (                            TextMessage,                            MultiModalMessage,                            StopMessage,                            HandoffMessage,                            ToolCallRequestEvent,                            ToolCallExecutionEvent,                            LLMCallEventMessage,                        ),                    ):                        await self._save_message(run_id, message)                    # 如果是TeamResult则捕获最终结果                    # Capture final result if it's a TeamResult                    elif isinstance(message, TeamResult):                        final_result = message.model_dump()                    self._team_managers[run_id] = team_manager  # 跟踪团队管理器            if (                not cancellation_token.is_cancelled()                and run_id not in self._closed_connections            ):                if final_result:                    await self._update_run(                        run_id, RunStatus.COMPLETE, team_result=final_result                    )                else:                    logger.warning(                        f"No final result captured for completed run {run_id}"                    )                    await self._update_run_status(run_id, RunStatus.COMPLETE)            else:                await self._send_message(                    run_id,                    {                        "type": "completion",                        "status": "cancelled",                        "data": self._cancel_message,                        "timestamp": datetime.now(timezone.utc).isoformat(),                    },                )                # 使用取消结果更新运行                # Update run with cancellation result                await self._update_run(                    run_id, RunStatus.STOPPED, team_result=self._cancel_message                )        except Exception as e:            logger.error(f"Stream error for run {run_id}: {e}")            traceback.print_exc()            await self._handle_stream_error(run_id, e)        finally:            self._cancellation_tokens.pop(run_id, None)            self._team_managers.pop(run_id, None)  # 完成后移除团队管理器

start_stream() 负责启动一个指定任务的 AI 团队流式运行流程:准备数据 → 执行 → 捕获输出 → 监控状态 → 推送结果。

检查连接是否存在
if run_id not in self._connections or run_id in self._closed_connections:    raise ValueError(f"No active connection for run {run_id}")

确保前端已经通过 WebSocket 成功连接。

准备 TeamManager 实例
if run_id not in self._team_managers:    team_manager = TeamManager(...)
准备取消控制器
cancellation_token = CancellationToken()self._cancellation_tokens[run_id] = cancellation_token
获取运行状态与用户配置
run = await self._get_run(run_id)user_settings = await self._get_settings(run.user_id)env_vars = SettingsConfig(...).environment
更新数据库运行状态并保存初始任务
run.task = ...run.status = RunStatus.ACTIVE

同时向前端推送消息:

await self._send_message(run_id, TextMessage(...))await self._save_message(run_id, ...)
创建输入函数,用于团队运行时交互
input_func: InputFuncType = self.create_input_func(run_id)

供运行中需要用户反馈的代理使用。

create_input_func

   def create_input_func(self, run_id: int, timeout: int = 600) -> InputFuncType:        """        为特定运行创建输入函数                参数:            run_id (int): 运行的ID            timeout (int, optional): 输入响应的超时时间(秒)。默认值:600        返回:            InputFuncType: 运行的输入函数        """        # Creates an input function for a specific run        #         # Args:        #     run_id (int): ID of the run        #     timeout (int, optional): Timeout for input response in seconds. Default: 600        # Returns:        #     InputFuncType: Input function for the run        async def input_handler(            prompt: str = "",            cancellation_token: Optional[CancellationToken] = None,            input_type: InputRequestType = "text_input",        ) -> str:            try:                # 如果运行已暂停则恢复运行                # resume run if it is paused                await self.resume_run(run_id)                # 将运行状态更新为等待输入                # update run status to awaiting_input                await self._update_run_status(run_id, RunStatus.AWAITING_INPUT)                # 向客户端发送输入请求                # Send input request to client                logger.info(                    f"Sending input request for run {run_id}: ({input_type}) {prompt}"                )                await self._send_message(                    run_id,                    {                        "type": "input_request",                        "input_type": input_type,                        "prompt": prompt,                        "data": {"source": "system", "content": prompt},                        "timestamp": datetime.now(timezone.utc).isoformat(),                    },                )                # 在Run对象中存储input_request                # Store input_request in the Run object                run = await self._get_run(run_id)                if run:                    run.input_request = {"prompt": prompt, "input_type": input_type}                    self.db_manager.upsert(run)                # 等待带有超时的响应                # Wait for response with timeout                if run_id in self._input_responses:                    try:                        async def poll_for_response():                            while True:                                # 检查运行是否已关闭/取消                                # Check if run was closed/cancelled                                if run_id in self._closed_connections:                                    raise ValueError("Run was closed")                                                                # 检查取消令牌是否已设置并已取消                                # Check if cancellation token is set and cancelled                                if cancellation_token and hasattr(cancellation_token, 'is_cancelled'):                                    try:                                        if cancellation_token.is_cancelled():                                            raise ValueError("Run was cancelled")                                    except Exception:                                        # 如果检查取消状态失败,则假定已取消                                        # If checking cancellation status fails, assume it's cancelled                                        raise ValueError("Run was cancelled")                                # 尝试使用短超时获取响应                                # Try to get response with short timeout                                try:                                    response = await asyncio.wait_for(                                        self._input_responses[run_id].get(),                                        timeout=min(timeout, 5),                                    )                                    await self._update_run_status(                                        run_id, RunStatus.ACTIVE                                    )                                    return response                                except asyncio.TimeoutError:                                    continue  # 继续检查关闭状态                        response = await asyncio.wait_for(                            poll_for_response(), timeout=timeout                        )                        return response                    except asyncio.TimeoutError:                        # 如果发生超时则停止运行                        # Stop the run if timeout occurs                        logger.warning(f"Input response timeout for run {run_id}")                        await self.stop_run(                            run_id,                            "Magentic-UI timed out while waiting for your input. To resume, please enter a follow-up message in the input box or you can simply type 'continue'.",                        )                        raise                else:                    raise ValueError(f"No input queue for run {run_id}")            except Exception as e:                logger.error(f"Error handling input for run {run_id}: {e}")                raise        return input_handler

为指定运行(run_id)创建一个异步函数,支持在代理任务中向用户发出输入请求、等待用户响应,并处理异常或超时情况

恢复运行并更新状态
await self.resume_run(run_id)await self._update_run_status(run_id, RunStatus.AWAITING_INPUT)
发送输入请求到前端
await self._send_message(    run_id,    {        "type": "input_request",        "input_type": input_type,        "prompt": prompt,        ...    },)
将请求记录到数据库(Run 对象中)
run.input_request = {"prompt": prompt, "input_type": input_type}self.db_manager.upsert(run)
等待用户输入(带超时)
if run_id in self._input_responses:

队列由 connect() 初始化,客户端消息处理器会通过此队列投递用户输入。

async def poll_for_response():    ...    response = await asyncio.wait_for(self._input_responses[run_id].get(), ...)

启动总超时等待

response = await asyncio.wait_for(poll_for_response(), timeout=timeout)

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

Magnetic-UI 团队创建 任务执行 源码解析 流式处理
相关文章