用《Flappy Bird》游戏带你搞懂强化学习:从Q-learning到DQN的保姆级实战
用《Flappy Bird》游戏带你搞懂强化学习从Q-learning到DQN的保姆级实战还记得2014年那个让人又爱又恨的《Flappy Bird》吗这只像素小鸟曾让无数玩家抓狂现在我们将用这个经典游戏作为实验室带你亲手打造一个会自己玩游戏的AI。这不是普通的编程教程而是一场从零开始的强化学习探险——不需要高深的数学基础只要你会写Python就能在3小时内见证AI从菜鸟到高手的进化历程。1. 环境搭建与游戏机制解析在PyGame中重建Flappy Bird只需要不到100行代码但这个简单游戏蕴含着强化学习的绝佳教学场景。让我们先拆解游戏的核心机制import pygame import random # 初始化游戏 pygame.init() screen pygame.display.set_mode((400, 600)) clock pygame.time.Clock() # 小鸟物理参数 bird_y 300 bird_velocity 0 gravity 0.25 flap_strength -5游戏状态可以简化为三个关键参数垂直距离小鸟与下一个管道开口中心的垂直差值水平距离小鸟与下一个管道开口的水平距离当前速度小鸟的瞬时垂直速度提示在强化学习中状态设计直接影响训练效果。过于简单的状态表示可能导致AI无法学习复杂策略。我们设计的奖励函数如下表所示事件即时奖励说明存活一帧0.1鼓励延长生存时间通过管道1主要目标奖励撞击障碍-1000强烈惩罚终止行为超出边界-1000防止逃避策略2. Q-learning实战从零构建决策表格Q-learning的核心是构建一个决策手册——Q表格它记录了在特定状态下采取某个动作的长期价值。对于我们的Flappy Birdimport numpy as np # 离散化状态空间 vertical_bins np.linspace(-200, 200, 20) # 垂直距离分20档 horizontal_bins np.linspace(0, 400, 20) # 水平距离分20档 velocity_bins np.linspace(-8, 8, 10) # 速度分10档 # 初始化Q表格 (状态1 × 状态2 × 状态3 × 动作) q_table np.zeros((20, 20, 10, 2))训练过程中的关键参数配置# 超参数设置 LEARNING_RATE 0.1 DISCOUNT_FACTOR 0.95 EPISODES 10000 epsilon 1.0 # 初始探索率 EPSILON_DECAY 0.9995训练循环的核心逻辑for episode in range(EPISODES): state env.reset() done False while not done: # ε-greedy策略 if random.random() epsilon: action random.randint(0, 1) # 随机探索 else: action np.argmax(q_table[state]) # 利用已知知识 next_state, reward, done env.step(action) # Q值更新公式 current_q q_table[state (action,)] max_next_q np.max(q_table[next_state]) new_q current_q LEARNING_RATE * (reward DISCOUNT_FACTOR * max_next_q - current_q) q_table[state (action,)] new_q state next_state epsilon * EPSILON_DECAY # 衰减探索率注意Q-learning面临维度灾难——当状态变量增加或精度要求提高时Q表格会指数级膨胀。这就是我们需要深度强化学习的原因。3. DQN进阶用神经网络替代Q表格Deep Q-Network (DQN) 用神经网络参数化Q函数解决了状态空间爆炸问题。我们使用PyTorch构建一个简单的CNNimport torch import torch.nn as nn class DQN(nn.Module): def __init__(self, input_shape, n_actions): super(DQN, self).__init__() self.conv nn.Sequential( nn.Conv2d(input_shape[0], 32, kernel_size8, stride4), nn.ReLU(), nn.Conv2d(32, 64, kernel_size4, stride2), nn.ReLU(), nn.Conv2d(64, 64, kernel_size3, stride1), nn.ReLU() ) conv_out_size self._get_conv_out(input_shape) self.fc nn.Sequential( nn.Linear(conv_out_size, 512), nn.ReLU(), nn.Linear(512, n_actions) ) def _get_conv_out(self, shape): o self.conv(torch.zeros(1, *shape)) return int(np.prod(o.size())) def forward(self, x): conv_out self.conv(x).view(x.size()[0], -1) return self.fc(conv_out)DQN引入了两个关键技术改进经验回放(Experience Replay)class ReplayBuffer: def __init__(self, capacity): self.buffer deque(maxlencapacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size)目标网络(Target Network)target_net DQN(input_shape, n_actions).to(device) target_net.load_state_dict(policy_net.state_dict()) target_update_counter 04. 训练技巧与性能优化在实际训练中我们发现几个关键技巧能显著提升表现学习率调度optimizer torch.optim.Adam(model.parameters(), lr0.0001) scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size10000, gamma0.9)帧堆叠(Frame Stacking)class FrameStack: def __init__(self, env, k): self.env env self.k k # 堆叠帧数 self.frames deque([], maxlenk) def reset(self): obs self.env.reset() for _ in range(self.k): self.frames.append(obs) return self._get_obs() def step(self, action): obs, reward, done, info self.env.step(action) self.frames.append(obs) return self._get_obs(), reward, done, info def _get_obs(self): return np.concatenate(list(self.frames), axis0)Double DQN (DDQN)改进# 普通DQN的Q值计算 q_values policy_net(states).gather(1, actions) # DDQN的Q值计算 with torch.no_grad(): next_actions policy_net(next_states).max(1)[1].unsqueeze(1) next_q_values target_net(next_states).gather(1, next_actions)训练过程中的典型性能指标变化训练阶段平均得分最大得分存活时间初期(0-1k)1.2530帧中期(1k-5k)15.742500帧后期(5k)68.3∞2000帧5. 可视化分析与调试技巧理解AI决策过程的关键是可视化。我们开发了几个调试工具Q值热力图def plot_q_values(state): q_values model(torch.FloatTensor(state).unsqueeze(0)) plt.imshow(q_values.detach().numpy(), cmaphot, interpolationnearest) plt.colorbar() plt.show()策略轨迹回放def replay_episode(model, env): state env.reset() frames [] done False while not done: frames.append(env.render(modergb_array)) action model.act(state) state, _, done, _ env.step(action) return frames常见问题排查指南AI完全不学习检查奖励函数设计验证梯度是否在更新调整学习率和折扣因子表现波动大增大经验回放缓冲区降低探索率衰减速度尝试DDQN结构过拟合当前环境引入随机初始条件使用课程学习策略添加正则化项在Colab笔记本上运行完整代码后你会看到AI从最初的随机乱飞到最终能无限生存的完整进化过程。有趣的是AI往往会发展出与人类不同的策略——比如紧贴管道上沿飞行以减少垂直移动幅度。