TRPO算法实战:用Python手把手教你训练CartPole平衡AI(附完整代码)
TRPO算法实战从零构建CartPole平衡AI的工程指南在强化学习领域策略优化算法一直是实现智能控制的核心工具。而Trust Region Policy OptimizationTRPO作为其中的经典方法通过独特的数学约束机制解决了传统策略梯度算法中步长选择困难、训练不稳定的痛点。本文将完全从工程实践角度出发带你用Python和PyTorch实现一个完整的TRPO智能体并在经典的CartPole控制环境中验证其效果。1. 环境搭建与核心概念1.1 CartPole环境解析CartPole是OpenAI Gym中最经典的测试环境之一其状态空间包含4个维度小车位置x小车速度ẋ杆子角度θ杆子角速度θ̇动作空间则是离散的两种选择0向左施加力1向右施加力奖励机制非常简单每保持杆子直立一步获得1奖励当杆子倾斜超过15度或小车移动超出边界时回合终止。一个理想的策略应该能持续保持杆子平衡达到最大步长通常为200步。import gym env gym.make(CartPole-v1) state_dim env.observation_space.shape[0] # 4 action_dim env.action_space.n # 21.2 TRPO的核心创新与传统策略梯度方法相比TRPO引入了三个关键机制信任区域约束通过KL散度限制新旧策略的差异防止策略更新过大导致性能崩溃共轭梯度优化高效计算自然梯度方向避免直接求逆高维Fisher信息矩阵替代目标函数使用重要性采样构建可优化的替代目标解决策略变化导致数据分布偏移的问题这些机制共同作用使得TRPO在保持训练稳定性的同时能够实现较大的策略更新步长。下面是我们将要实现的算法框架初始化策略网络和价值网络 for 每次迭代: 1. 用当前策略收集轨迹数据 2. 计算各状态的优势估计GAE 3. 构建替代目标函数 4. 用共轭梯度法计算更新方向 5. 通过线性搜索确定最优步长 6. 更新策略网络参数 7. 更新价值网络参数2. 神经网络架构设计2.1 策略网络Actor策略网络接收环境状态输出动作的概率分布。对于CartPole这样的离散动作空间我们使用带Softmax输出的全连接网络import torch import torch.nn as nn import torch.nn.functional as F class PolicyNet(nn.Module): def __init__(self, state_dim, hidden_dim, action_dim): super(PolicyNet, self).__init__() self.fc1 nn.Linear(state_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, action_dim) def forward(self, x): x F.relu(self.fc1(x)) return F.softmax(self.fc2(x), dim1)2.2 价值网络Critic价值网络用于估计状态价值函数帮助计算优势函数。它是一个回归网络输出单个标量值class ValueNet(nn.Module): def __init__(self, state_dim, hidden_dim): super(ValueNet, self).__init__() self.fc1 nn.Linear(state_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, 1) def forward(self, x): x F.relu(self.fc1(x)) return self.fc2(x)2.3 网络初始化技巧为了训练稳定性建议采用以下初始化策略def init_weights(m): if isinstance(m, nn.Linear): nn.init.orthogonal_(m.weight, gain0.01) nn.init.constant_(m.bias, 0) actor PolicyNet(state_dim, 64, action_dim) critic ValueNet(state_dim, 64) actor.apply(init_weights) critic.apply(init_weights)正交初始化配合较小的增益值有助于保持初始策略的随机性避免过早陷入局部最优。3. 核心算法实现3.1 广义优势估计GAEGAE通过结合多步TD误差提供低方差、低偏差的优势估计def compute_advantage(gamma, lmbda, td_delta): td_delta td_delta.detach().numpy() advantage_list [] advantage 0.0 for delta in td_delta[::-1]: advantage gamma * lmbda * advantage delta advantage_list.append(advantage) advantage_list.reverse() return torch.tensor(advantage_list, dtypetorch.float)其中td_delta是时序差分误差δ_t r_t γ*V(s_{t1}) - V(s_t)3.2 共轭梯度法共轭梯度法用于高效求解自然梯度方向避免直接计算和存储Fisher信息矩阵def conjugate_gradient(actor, states, old_action_dists, b, max_iter10, residual_tol1e-10): x torch.zeros_like(b) r b.clone() p b.clone() rdotr torch.dot(r, r) for _ in range(max_iter): Hp fisher_vector_product(actor, states, old_action_dists, p) alpha rdotr / torch.dot(p, Hp) x alpha * p r - alpha * Hp new_rdotr torch.dot(r, r) if new_rdotr residual_tol: break beta new_rdotr / rdotr p r beta * p rdotr new_rdotr return x3.3 线性搜索线性搜索确保策略更新既提升性能又满足KL约束def line_search(actor, states, actions, advantage, old_log_probs, old_action_dists, full_step, max_backtracks15, accept_ratio0.1): old_params torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()) old_obj surrogate_obj(actor, states, actions, advantage, old_log_probs) for step_frac in [0.5**i for i in range(max_backtracks)]: new_params old_params step_frac * full_step new_actor copy.deepcopy(actor) torch.nn.utils.convert_parameters.vector_to_parameters(new_params, new_actor.parameters()) new_obj surrogate_obj(new_actor, states, actions, advantage, old_log_probs) kl mean_kl_div(new_actor, states, old_action_dists) if new_obj old_obj accept_ratio * step_frac and kl kl_constraint: return new_params return old_params4. 完整训练流程4.1 超参数设置TRPO对超参数相对鲁棒但以下设置在实践中表现良好参数推荐值说明γ (gamma)0.99折扣因子λ (lambda)0.97GAE衰减系数KL约束0.01信任区域大小批大小5000每批采样步数策略LR-TRPO不直接使用学习率价值LR1e-3价值网络学习率隐藏层64网络隐藏单元数4.2 训练循环实现def train(env, agent, num_episodes, batch_size5000): return_list [] for i in range(num_episodes): # 数据收集 states, actions, rewards, next_states, dones [], [], [], [], [] state env.reset() episode_return 0 for _ in range(batch_size): action agent.take_action(state) next_state, reward, done, _ env.step(action) states.append(state) actions.append(action) rewards.append(reward) next_states.append(next_state) dones.append(done) state next_state if not done else env.reset() episode_return reward # 转换为张量 states torch.FloatTensor(np.array(states)) actions torch.LongTensor(np.array(actions)).unsqueeze(1) rewards torch.FloatTensor(np.array(rewards)) next_states torch.FloatTensor(np.array(next_states)) dones torch.FloatTensor(np.array(dones)) # 计算优势 with torch.no_grad(): values agent.critic(states) next_values agent.critic(next_states) td_delta rewards gamma * next_values * (1 - dones) - values advantages compute_advantage(gamma, lmbda, td_delta) # 策略更新 agent.update(states, actions, advantages) # 记录训练过程 return_list.append(episode_return / (batch_size / 200)) # 标准化回报 print(fEpisode {i}, Return: {return_list[-1]:.1f}) return return_list5. 调试与性能优化5.1 常见问题排查训练初期回报不增长检查优势函数计算是否正确验证策略网络是否输出了合理的动作分布确保KL约束没有设置过小训练过程波动大尝试减小KL约束阈值增加批处理大小检查价值网络学习率是否过高最终性能不理想尝试调整GAE参数λ增加网络容量延长训练时间5.2 性能优化技巧向量化计算 使用PyTorch的批量操作替代循环大幅提升计算效率并行数据收集 使用多环境并行收集数据减少训练时间from multiprocessing import Pool def collect_episode(env_fn): env env_fn() state env.reset() episode [] for _ in range(1000): action agent.take_action(state) next_state, reward, done, _ env.step(action) episode.append((state, action, reward, next_state, done)) if done: break state next_state return episode自动微分优化 使用torch.autograd.functional计算高阶导数提高代码可读性和运行效率6. 进阶扩展6.1 连续动作空间适配对于连续动作空间环境如Pendulum需要对策略网络进行修改class ContinuousPolicyNet(nn.Module): def __init__(self, state_dim, hidden_dim, action_dim): super().__init__() self.fc1 nn.Linear(state_dim, hidden_dim) self.fc_mu nn.Linear(hidden_dim, action_dim) self.fc_std nn.Linear(hidden_dim, action_dim) def forward(self, x): x F.relu(self.fc1(x)) mu torch.tanh(self.fc_mu(x)) * 2 # 假设动作范围[-2,2] std F.softplus(self.fc_std(x)) 1e-5 return torch.distributions.Normal(mu, std)6.2 与其他算法结合TRPO可以与以下技术结合获得更好效果PPO-Clip使用裁剪替代KL约束简化实现Q-Prop结合值函数和策略梯度方法分布式训练使用Ape-X框架加速探索7. 实际应用建议环境设计确保奖励函数设计合理考虑添加适当的观察噪声对状态进行标准化处理训练策略先在小批量数据上验证算法实现使用Tensorboard记录训练曲线定期保存模型快照部署考量将PyTorch模型导出为ONNX格式量化模型减小推理开销添加安全约束防止异常行为在完成上述实现后你应该能看到CartPole环境中的回报曲线稳步上升最终达到最大步长200步。TRPO的强大之处在于其稳定性——即使超参数不是最优通常也能获得不错的结果这正是信任区域方法的核心优势。