引言
在游戏与模拟世界的交汇处,强化学习(Reinforcement Learning, RL)宛如一位神秘导师,教会智能体如何在未知环境中探索、试错并最终精通任务。深度 Q 网络(Deep Q-Network, DQN)作为强化学习领域的里程碑算法,成功将深度学习的表征能力与强化学习的价值迭代方法结合,使智能体能在复杂环境中做出精准决策。而 CartPole 游戏,这个经典的强化学习环境,如同一片完美的试验田,让初学者能清晰观察强化学习算法从懵懂到精通的全过程。今天,就让我们一起踏上这段旅程,利用 Trae(假设为强化学习框架或工具库)实现 DQN,让智能体学会平衡木杆!
I. 强化学习基础理论
强化学习的四大核心元素
强化学习的舞台由四个关键元素搭建而成。智能体(Agent)是决策者,身处环境(Environment)之中;环境为智能体提供状态(State),如同游戏中的关卡信息;智能体依据策略(Policy)选择动作(Action),例如在游戏中移动角色;环境根据动作给出奖励(Reward)并切换到新状态。以 CartPole 为例,状态包含小车位置、速度,木杆角度、角速度;动作是向左或向右推动小车;奖励为每保持木杆平衡一个时间步长获得 +1;智能体的目标是最大化累计奖励,即让木杆平衡尽可能久。
马尔可夫决策过程(MDP)的数学基石
强化学习的理论根基是马尔可夫决策过程(Markov Decision Process, MDP)。它假设环境状态转移满足马尔可夫性质:下一状态 ( s' ) 仅依赖当前状态 ( s ) 和动作 ( a ),与历史状态无关。状态转移概率 ( P(s'|s,a) ) 定义从状态 ( s ) 执行动作 ( a ) 后进入状态 ( s' ) 的概率;奖励函数 ( R(s,a) ) 给出对应动作的即时奖励期望。策略 ( \pi(a|s) ) 描述智能体在状态 ( s ) 选择动作 ( a ) 的概率分布。智能体的任务是在策略空间中搜索使期望累计奖励 ( \mathbb{E}[\sum_{t=0}^T \gamma^t r_t] ) 最大化的策略,其中 ( \gamma ) 是折扣因子,平衡短期与长期奖励。
Q 学习的核心思想
Q 学习(Q-Learning)是强化学习的价值迭代方法之一,核心在于构建 Q 表(Quality Table),记录状态 - 动作对的 Q 值,即执行动作 ( a ) 后从状态 ( s ) 开始的期望累计奖励。Q 学习基于贝尔曼方程(Bellman Equation)更新 Q 值:( Q(s,a) \leftarrow Q(s,a) + \alpha [r + \gamma \max_{a'} Q(s',a') - Q(s,a)] ),其中 ( \alpha ) 是学习率。智能体通过不断探索环境,逐步优化 Q 表,最终学会最优策略。
DQN 的突破性创新
深度 Q 网络(DQN)解决了传统 Q 学习在高维状态空间下的 Q 表爆炸问题。它引入深度神经网络作为 Q 函数的近似器,将状态向量映射到各动作的 Q 值。为稳定训练,DQN 提出两大创新:经验回放(Experience Replay)与目标网络(Target Network)。经验回放将智能体与环境交互产生的样本(状态、动作、奖励、下一状态)存储在回放缓冲区,训练时随机抽取小批次样本,打破样本相关性,减少方差;目标网络是 Q 网络的复制,定期更新参数,为 Q 值更新提供稳定的目标值,避免网络输出波动导致的训练发散。
Mermaid 图形总结
graph TD A[强化学习与 DQN 框架] --> B[强化学习四大元素] B --> C[智能体与环境] B --> D[状态、动作、奖励] A --> E[马尔可夫决策过程] E --> F[状态转移概率] E --> G[奖励函数与策略] A --> H[Q 学习算法] H --> I[Q 表与贝尔曼更新] A --> J[DQN 创新点] J --> K[深度神经网络近似 Q 函数] J --> L[经验回放] J --> M[目标网络]
强化学习与传统机器学习对比
特性 | 强化学习 | 传统机器学习(监督学习) |
---|---|---|
数据获取方式 | 通过智能体与环境交互产生 | 依赖预标注数据集 |
反馈类型 | 延迟奖励,稀疏反馈 | 立即损失函数反馈 |
智能体主动性 | 智能体主动探索环境 | 被动接受数据 |
优化目标 | 最大化长期累计奖励 | 最小化即时预测误差 |
训练过程 | 试错迭代,策略逐步改进 | 批量或在线梯度下降 |
II. CartPole 问题深度剖析
CartPole 环境的物理本质
CartPole 是强化学习领域的经典模拟环境,它模拟了一个简单物理系统:一根木杆通过铰链连接在小车上,小车能在水平轨道上左右移动;木杆初始时刻处于直立状态,智能体通过控制小车左右运动,目标是让木杆保持平衡尽可能久。其物理本质涉及经典力学中的倒立摆系统,包含小车位置 ( x )、速度 ( \dot{x} )、木杆角度 ( \theta )、角速度 ( \dot{\theta} ),以及重力 ( g )、木杆长度 ( l )、质量 ( m ) 等参数。系统动态遵循牛顿运动定律与刚体转动方程,小车的推力产生加速度,带动木杆旋转。
状态空间与动作空间
在 Gym 的 CartPole-v1 环境中,状态空间为连续四维向量 ( [x, \dot{x}, \theta, \dot{\theta}] ),其中 ( x ) 范围为 ( [-4.8, 4.8] ),( \theta ) 范围为 ( [-0.418 \text{ rad}, 0.418 \text{ rad}] ),速度和角速度无严格边界但会随系统动态变化;动作空间为离散二维 ( {0, 1} ),分别代表向左推小车(力 ( -10 \text{ N} ))和向右推小车(力 ( +10 \text{ N} ))。奖励机制简单直接,只要木杆未倾倒且小车未超出轨道边界,每时间步长奖励 +1,否则结束回合(episode),最大回合步长为 500。
CartPole 的挑战与解决思路
CartPole 问题的挑战在于状态空间连续且维度较高,直接应用传统 Q 学习的 Q 表方法不现实;木杆平衡是一个动态不稳定系统,小车的微小推力可能导致木杆快速倾倒,需要智能体精准控制。解决思路是利用 DQN 的深度神经网络近似 Q 函数,将连续状态映射到离散动作的 Q 值;通过经验回放稳定训练过程,避免相关样本导致的梯度估计偏差;设置合理的探索策略,如 (\epsilon)-贪心,在探索与利用之间取得平衡,逐步学会让木杆长期平衡的策略。
Mermaid 图形总结
graph TD A[CartPole 问题分析] --> B[物理系统本质] B --> C[倒立摆动力学] B --> D[状态变量与控制力] A --> E[状态与动作空间] E --> F[连续状态向量] E --> G[离散动作集] A --> H[问题挑战] H --> I[高维连续状态] H --> J[动态不稳定系统] A --> K[解决思路] K --> L[DQN 网络近似] K --> M[经验回放与探索策略]
CartPole 物理参数对比
参数 | 符号 | 典型值 | 单位 |
---|---|---|---|
重力加速度 | ( g ) | 9.8 | ( \text{m/s}^2 ) |
木杆长度 | ( l ) | 0.5 | m |
木杆质量 | ( m ) | 0.1 | kg |
小车质量 | ( M ) | 1.0 | kg |
推力 | ( F ) | ±10 | N |
摩擦系数 | ( \mu ) | 0.0005 | 无量纲 |
III. DQN 网络架构与 Trae 实现
DQN 网络结构设计
DQN 的核心是神经网络结构,它将 CartPole 的四维状态向量映射到两个可能动作的 Q 值。一个简洁而有效的网络架构包含输入层(4 个神经元)、两个隐藏层(各 64 个神经元,激活函数采用 ReLU)、输出层(2 个神经元,对应左右动作的 Q 值)。输入层接收状态向量 ( [x, \dot{x}, \theta, \dot{\theta}] ),第一隐藏层通过线性变换与 ReLU 激活提取初级特征,如小车位置与速度的关联、木杆角度与角速度的组合;第二隐藏层进一步融合特征,挖掘状态与动作价值的深层关系;输出层给出每个动作的预期累计奖励 Q 值。
Trae 中的 DQN 实现
在 Trae 框架中(假设为强化学习库),实现 DQN 网络需定义网络结构、构建经验回放缓冲区、设置目标网络更新机制以及定义训练循环。以下是核心代码片段及详细解释:
import torchimport torch.nn as nnimport torch.optim as optimimport numpy as npfrom collections import deque# 定义 DQN 网络结构class DQN(nn.Module): def __init__(self, input_dim, output_dim): super(DQN, self).__init__() self.fc1 = nn.Linear(input_dim, 64) # 输入层到第一隐藏层 self.fc2 = nn.Linear(64, 64) # 第一隐藏层到第二隐藏层 self.fc3 = nn.Linear(64, output_dim) # 第二隐藏层到输出层 def forward(self, x): x = torch.relu(self.fc1(x)) # ReLU 激活第一隐藏层 x = torch.relu(self.fc2(x)) # ReLU 激活第二隐藏层 return self.fc3(x) # 输出原始 Q 值,不采用激活函数# 初始化 DQN 网络与目标网络input_dim = 4 # CartPole 状态维度output_dim = 2 # 动作维度(左、右)policy_net = DQN(input_dim, output_dim)target_net = DQN(input_dim, output_dim)target_net.load_state_dict(policy_net.state_dict()) # 同步目标网络参数target_net.eval() # 目标网络设置为评估模式# 定义优化器与损失函数optimizer = optim.Adam(policy_net.parameters(), lr=0.001)loss_fn = nn.MSELoss()# 构建经验回放缓冲区class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def add(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): batch = np.random.choice(len(self.buffer), batch_size, replace=False) return [self.buffer[i] for i in batch] def __len__(self): return len(self.buffer)replay_buffer = ReplayBuffer(capacity=10000)
经验回放与目标网络机制
经验回放与目标网络是 DQN 稳定训练的两大支柱。在每次智能体与环境交互后,将产生的样本(状态、动作、奖励、下一状态、是否结束)存入经验回放缓冲区。训练时,随机从小批次样本中抽取,打破样本之间的时序相关性,减少梯度估计的方差,使网络更新更加稳定。目标网络则是策略网络的复制,定期(如每 1000 个训练步)从策略网络复制参数,为 Q 值更新提供稳定的目标值。在计算 Q 值目标时,使用目标网络评估下一状态的最大 Q 值,避免策略网络自身输出波动导致的训练发散。
# 训练循环中的经验回放与目标网络更新def train_dqn(policy_net, target_net, replay_buffer, batch_size, gamma): if len(replay_buffer) < batch_size: return # 经验回放样本不足时不训练 # 从小批次样本中抽取 batch = replay_buffer.sample(batch_size) states, actions, rewards, next_states, dones = zip(*batch) # 转换为张量 states = torch.tensor(states, dtype=torch.float32) actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1) rewards = torch.tensor(rewards, dtype=torch.float32) next_states = torch.tensor(next_states, dtype=torch.float32) dones = torch.tensor(dones, dtype=torch.float32) # 计算当前 Q 值 current_q = policy_net(states).gather(1, actions) # 计算目标 Q 值 with torch.no_grad(): next_q = target_net(next_states).max(1)[0] target_q = rewards + gamma * next_q * (1 - dones) # 计算损失并反向传播 loss = loss_fn(current_q.squeeze(), target_q) optimizer.zero_grad() loss.backward() optimizer.step()# 目标网络参数更新函数def update_target_net(policy_net, target_net): target_net.load_state_dict(policy_net.state_dict())
探索策略与训练参数调试
智能体的探索策略决定其在环境中的试错方式。(\epsilon)-贪心策略是经典选择:以概率 ( \epsilon ) 随机选择动作,以概率 ( 1 - \epsilon ) 选择当前 Q 网络评估的最优动作。初始阶段设置较大的 ( \epsilon )(如 1.0),让智能体充分探索环境;随着训练进行,逐步衰减 ( \epsilon )(如每 1000 步衰减 0.995 倍),增加利用已学知识的比重,最终稳定在较小值(如 0.01),保持少量探索避免过早收敛。
训练参数调试是 DQN 优化的关键环节。学习率(lr)设置过高会导致网络更新过激,Q 值估计发散;过低则训练收敛缓慢。初始学习率可设为 0.001,根据训练曲线动态调整。折扣因子 ( \gamma ) 控制智能体对长期奖励的重视程度,通常设为 0.9 至 0.99。批次大小(batch_size)影响训练稳定性与效率,小批次(32)训练速度快但方差大,大批次(128)更稳定但计算开销大,64 是常用折中值。目标网络更新频率(如每 1000 步)要平衡新鲜参数与稳定目标之间的矛盾。
Mermaid 图形总结
graph TD A[DQN 网络与训练机制] --> B[DQN 网络结构] B --> C[输入层(4 维状态)] B --> D[两隐藏层(64 单元,ReLU)] B --> E[输出层(2 动作 Q 值)] A --> F[经验回放机制] F --> G[存储交互样本] F --> H[小批次随机采样] A --> I[目标网络机制] I --> J[定期复制参数] I --> K[稳定目标 Q 值] A --> L[探索策略] L --> M[ε-贪心策略] L --> N[参数衰减调度] A --> O[训练参数调试] O --> P[学习率与 γ 调整] O --> Q[批次大小与更新频率]
DQN 训练参数对比
参数 | 典型值 | 调试范围 | 影响 |
---|---|---|---|
学习率(lr) | 0.001 | 0.0001 - 0.01 | 决定 Q 值更新步长,影响收敛速度与稳定性 |
折扣因子(γ) | 0.99 | 0.9 - 0.999 | 控制长期奖励权重,值越高越注重未来收益 |
批次大小 | 64 | 32 - 128 | 影响方差与计算效率,小批次高方差、快更新 |
目标网络更新频率 | 1000 步 | 200 - 2000 步 | 决定目标网络参数新鲜度与稳定性平衡 |
ε 初始值 | 1.0 | 0.8 - 1.0 | 初始探索程度,值越高探索越充分 |
ε 衰减率 | 0.995 每千步 | 0.99 - 0.999 每百至千步 | 控制探索到利用的过渡速度 |
IV. DQN 训练过程深度解析
环境交互与样本收集
DQN 训练始于智能体与 CartPole 环境的交互。在每个回合(episode)开始时,环境重置为初始状态;智能体观察状态,依据当前策略(初始阶段多为随机探索)选择动作;环境执行动作,返回下一状态、奖励及是否结束信号;智能体将交互样本存储到经验回放缓冲区;当回合结束(木杆倾倒或达到最大步长)时,统计回合奖励(即平衡木杆的时间步数),记录到训练日志。这一过程循环往复,智能体在不断试错中积累经验。
# 环境交互与样本收集代码示例import gymenv = gym.make('CartPole-v1')num_episodes = 1000gamma = 0.99batch_size = 64epsilon = 1.0epsilon_decay = 0.995min_epsilon = 0.01rewards_history = [] # 记录每个回合的奖励for episode in range(num_episodes): state = env.reset() # 重置环境,获取初始状态 episode_reward = 0 done = False while not done: # ε-贪心策略选择动作 if np.random.random() < epsilon: action = np.random.randint(2) # 随机选择动作 else: with torch.no_grad(): state_tensor = torch.tensor(state, dtype=torch.float32) q_values = policy_net(state_tensor) action = q_values.argmax().item() # 选择 Q 值最大的动作 # 执行动作,获取环境反馈 next_state, reward, done, _ = env.step(action) episode_reward += reward # 存储样本到经验回放缓冲区 replay_buffer.add(state, action, reward, next_state, done) # 更新状态 state = next_state # 训练 DQN 网络 if len(replay_buffer) >= batch_size: train_dqn(policy_net, target_net, replay_buffer, batch_size, gamma) # 更新目标网络(简单起见,这里每回合更新一次) if episode % 10 == 0: update_target_net(policy_net, target_net) # 衰减 ε 值 epsilon = max(min_epsilon, epsilon * epsilon_decay) rewards_history.append(episode_reward) print(f"Episode {episode + 1}, Reward: {episode_reward}, Epsilon: {epsilon:.2f}")
训练收敛性分析
观察 DQN 训练过程中的回合奖励曲线,初期奖励波动剧烈且值较低,智能体主要通过随机探索与环境交互,偶尔碰巧找到能让木杆平衡数步的方法;中期奖励逐渐上升并趋于稳定,智能体学会基本平衡技巧,如通过小车左右摆动修正木杆倾角;后期奖励稳定在较高水平(接近 500,即达到环境最大步长),表明智能体已掌握稳定策略,能长期保持木杆平衡。若奖励曲线出现振荡或下降,可能暗示超参数设置不当,如学习率过高导致 Q 值估计发散,或目标网络更新频率过低使目标值滞后。
训练过程中的常见问题与解决方法
问题 | 可能原因 | 解决方法 |
---|---|---|
回合奖励长期停滞在低水平 | 探索不足,智能体陷入局部策略 | 增大 ε 初始值或减缓 ε 衰减速度,鼓励更多探索 |
Q 值估计发散,损失函数急剧上升 | 学习率过高或目标网络更新过频 | 降低学习率(如调至 0.0001),增加目标网络更新间隔 |
训练初期收敛快但后期性能下降 | 经验回放缓冲区容量不足,样本过时 | 增大缓冲区容量(如至 50000),或采用优先经验回放 |
智能体策略过于保守,仅能应对小扰动 | 网络容量不足,无法拟合复杂策略 | 增加隐藏层神经元数量(如至 128)或添加一层 |
Mermaid 图形总结
graph TD A[DQN 训练过程框架] --> B[环境交互循环] B --> C[重置环境获取初始状态] B --> D[ε-贪心选择动作] B --> E[执行动作获取反馈] B --> F[存储样本到回放缓冲区] A --> G[训练循环] G --> H[检查回放缓冲区样本量] G --> I[小批次采样与训练] G --> J[更新目标网络] A --> K[训练监控] K --> L[记录奖励曲线] K --> M[观察 Q 值与损失变化] A --> N[常见问题排查] N --> O[奖励停滞解决方法] N --> P[Q 值发散应对策略]
V. DQN 性能评估与优化
性能评估指标与方法
评估 DQN 在 CartPole 任务中的性能,主要关注以下指标:平均回合奖励(Average Episode Reward),取最近 100 个回合奖励均值,衡量智能体策略稳定性;成功回合率(Success Rate),统计在 100 个回合中达到最大步长(500 步)的回合比例,反映策略可靠性;Q 值估计合理性,观察 Q 值分布,若 Q 值持续攀升可能暗示过估计问题,需引入双 DQN(Double DQN)等改进方法。
可视化分析 Q 值与策略
通过可视化 Q 值与智能体策略,能深入理解 DQN 学到了什么。绘制状态空间中部分典型状态(如不同木杆角度、小车位置组合)对应的 Q 值热力图,Q 值越高表示智能体认为在该状态下执行对应动作越有利。例如,当木杆向右倾斜(θ > 0),Q 值热力图显示在该状态区域,向右推动小车(动作 1)的 Q 值普遍高于向左推动(动作 0),表明智能体学会通过小车向右移动来平衡木杆。同时,可录制智能体执行策略的视频,直观观察其平衡木杆的动作模式,如是否主动修正木杆倾角、小车移动幅度是否合理等。
# Q 值热力图可视化代码示例import matplotlib.pyplot as pltimport numpy as np# 生成状态网格(简化为木杆角度与小车位置二维状态)theta_range = np.linspace(-0.418, 0.418, 50)x_range = np.linspace(-4.8, 4.8, 50)Q_values = np.zeros((len(x_range), len(theta_range), 2)) # 记录各状态动作 Q 值for i, x in enumerate(x_range): for j, theta in enumerate(theta_range): # 构造完整状态(其余状态变量设为 0) state = torch.tensor([x, 0.0, theta, 0.0], dtype=torch.float32) with torch.no_grad(): q = policy_net(state) Q_values[i, j, :] = q.numpy()# 绘制动作 0(向左)的 Q 值热力图plt.figure(figsize=(10, 6))X, THETA = np.meshgrid(x_range, theta_range)plt.contourf(X, THETA, Q_values[:, :, 0].T, levels=50, cmap='viridis')plt.colorbar(label='Q Value (Action Left)')plt.xlabel('Cart Position')plt.ylabel('Pole Angle')plt.title('Q Values for Action Left Across State Space')plt.show()
DQN 优化进阶技巧
若基础 DQN 的表现仍不理想,可引入多种优化技巧提升性能。双 DQN(Double DQN)通过用策略网络选择动作、目标网络评估 Q 值,减少 Q 值过估计问题;优先经验回放(Prioritized Experience Replay)根据样本的 TD 误差赋予不同样本不同采样概率,高频采样重要样本,提高样本利用效率; Dueling DQN 将网络输出分解为价值函数(Value Function)与优势函数(Advantage Function),更精准地更新 Q 值;分布强化学习(Distributional RL)直接建模 Q 值分布而非期望值,提升模型对不确定性建模能力。
# 双 DQN 训练修改代码片段def train_double_dqn(policy_net, target_net, replay_buffer, batch_size, gamma): if len(replay_buffer) < batch_size: return batch = replay_buffer.sample(batch_size) states, actions, rewards, next_states, dones = zip(*batch) states = torch.tensor(states, dtype=torch.float32) actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1) rewards = torch.tensor(rewards, dtype=torch.float32) next_states = torch.tensor(next_states, dtype=torch.float32) dones = torch.tensor(dones, dtype=torch.float32) current_q = policy_net(states).gather(1, actions) # 双 DQN 的核心:用策略网络选动作,目标网络评估 Q 值 with torch.no_grad(): next_actions = policy_net(next_states).argmax(1) # 策略网络选择动作 next_q = target_net(next_states).gather(1, next_actions.unsqueeze(1)) # 目标网络评估 Q 值 target_q = rewards + gamma * next_q.squeeze() * (1 - dones) loss = loss_fn(current_q.squeeze(), target_q) optimizer.zero_grad() loss.backward() optimizer.step()
Mermaid 图形总结
graph TD A[DQN 性能评估与优化] --> B[性能评估指标] B --> C[平均回合奖励] B --> D[成功回合率] B --> E[Q 值合理性检查] A --> F[可视化分析] F --> G[Q 值热力图] F --> H[策略执行录像] A --> I[优化方法] I --> J[双 DQN] I --> K[优先经验回放] I --> L[Dueling DQN] I --> M[分布强化学习]
DQN 优化方法对比
优化方法 | 主要改进点 | 适用场景 | 实现复杂度 |
---|---|---|---|
双 DQN | 分离动作选择与 Q 值评估,减少过估计 | Q 值易过估计的任务 | 中等 |
优先经验回放 | 根据 TD 误差采样重要样本 | 样本效率低的任务 | 高 |
Dueling DQN | 分解 Q 值为价值与优势函数 | 状态价值差异大的任务 | 中 |
分布强化学习 | 建模 Q 值分布而非期望值 | 不确定性高的任务 | 高 |
VI. DQN 部署实战:从训练到推理
推理服务架构设计
将训练好的 DQN 模型部署为推理服务,需设计合理架构。在服务器端,采用 Flask 或 FastAPI 构建 RESTful API,接收客户端发送的 CartPole 状态数据;加载训练好的 DQN 模型权重,预处理输入状态(归一化等);调用模型进行前向传播,获取动作 Q 值;选择 Q 值最大的动作作为推荐控制信号返回客户端。部署时可采用 Docker 容器化服务,确保环境隔离与可移植性;在多线程或异步模式下运行 API,提升服务并发能力。
# 推理服务代码示例(使用 FastAPI)from fastapi import FastAPIimport uvicornimport torchimport numpy as npapp = FastAPI()# 加载 DQN 模型model = DQN(input_dim=4, output_dim=2)model.load_state_dict(torch.load("dqn_cartpole.pth"))model.eval()@app.post("/predict")async def predict(state: list): # 预处理状态数据(示例状态:[x, x_dot, theta, theta_dot]) state_tensor = torch.tensor(state, dtype=torch.float32) with torch.no_grad(): q_values = model(state_tensor) action = q_values.argmax().item() return {"action": action}if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000)
模型保存与加载细节
在 DQN 训练完成后,保存模型的推荐方法是仅保存模型状态字典(state_dict),而非整个模型对象,这样可提升加载效率与兼容性。保存时执行 torch.save(policy_net.state_dict(), "dqn_cartpole.pth")
;加载时先初始化网络结构,再调用 model.load_state_dict(torch.load("dqn_cartpole.pth"))
。若在不同设备(如训练在 GPU、推理在 CPU)间迁移模型,加载时需指定映射:`model.load_state_dict(torch.load("dqn_cartpole.pth", map_location=torch.device('cpu')))。
推理延迟优化策略
推理延迟是部署服务的关键性能指标。在模型层面,可尝试减小网络规模(如将隐藏层神经元减至 32),以牺牲一定性能为代价提升推理速度;采用模型量化技术,将 32 位浮点数权重量化为 8 位整数,减少计算量与内存占用;在代码层面,优化输入预处理流程,避免不必要的数据拷贝与格式转换;使用 GPU 进行推理时,确保输入数据在 GPU 内存中,并利用 CUDA 图(CUDA Graph)记录与重放推理步骤,减少内核启动开销。
# 模型量化代码示例from torch.ao.quantization import get_default_qconfig, quantize_dynamic# 动态量化 DQN 模型qconfig = get_default_qconfig("fbgemm")quantized_model = quantize_dynamic( model, # 待量化的模型 {nn.Linear}, # 量化线性层 dtype=torch.qint8, qconfig_spec=qconfig)# 保存量化模型torch.save(quantized_model.state_dict(), "dqn_quantized.pth")