PyTorch强化学习实战(18)——强化学习训练加速

开发过程中有些细节容易被忽略,今天挑几个重点聊一聊。

PyTorch强化学习实战(18)——强化学习训练加速

0. 前言

咱们已经学习了若干提升深度Q网络 (Deep Q-Network, DQN)办法稳定性与收敛速度的实用技巧。这些技巧通过对经典 DQN 办法进行改良(举个例子向网络注入噪声或展开贝尔曼方程),以更短的训练时间获得更优策略。而本节咱们将探索另一种加速路径:通过调整办法实现细节来提升训练效率。这虽属纯工程优化范畴,却因其实用价值而至关重要。

1. 训练速度的重要性

首先,咱们探讨速度优化的意义。人工智能 (Artificial Intelligence, AI) 与机器学习 (Machine Learning, ML) 的进步始终由数据可用性和计算力提升驱动。假设某项需单机运行一个月的计算任务,若速度提升 5 倍,漫长等待将缩短至 6 天;100 倍加速则意味着原本耗时一个月的运算 8 小时即可做好。
这种提升不仅发生在高性能计算领域,更已渗透各个角落。现代微控制器的性能已堪比 15 年前的台式机。至于配备四至八核 CPUGPU 及数 GB 内存的现代智能手机,更是不言自明的算力典范。
另一方面,加速的必要性或许并不直观。有人会说:一个月也不算长,把电脑自行执行就好了。但思考整个计算任务的准备与执行过程会发现,即便是简单机器学习问题,首次尝试就完美实现几乎不可能——通常需要多次试运行来调整超参数、修复漏洞,才能最终部署。物理模拟、强化学习研究、大数据处理乃至广义编程皆遵循相同规律。于是加速不仅能缩短单次运行时间,更能加快迭代速度,从而显著提升整体效率与最终成果质量。
优化还有另一重意义:它能拓展我们处理问题的规模。方法加速可体现为两种形式:更快获得结果,或处理更复杂问题(如提高精度、减少现实简化、纳入更多数据)。
那么强化学习 (Reinforcement Learning, RL)方法如何从加速中受益。首先,即便最先进的 RL 方法样本效率仍不高,举个例子 Atari 游戏需与环境交互数百万次才能习得优秀策略,意味着数周训练周期。适度加速既能更快获取结果,也能开展更多实验以优化超参数。更重要的是,代码效率提升后,我们能挑战更复杂的问题。
在现代强化学习领域,Atari 游戏已经被认为是解决的问题,即使是困难探索游戏,像《蒙特祖马的复仇》,也能训练出超越人类水平的精度。于是研究前沿正转向观测空间与动作空间更复杂的难题,这必然需要更长的训练周期与更强的硬件支持,举个例子蛋白质折叠预测( AlphaFold 系统)和大语言模型 (Large Language Model, LLM)等更具挑战性的课题。
所有性能优化都建立在核心方法有效运行的基础上(而 RLML 领域的方法有效性往往难以直观判断),拥有一个运行缓慢但正确的程序总比有一个速度飞快却漏洞百出的程序要好。

2. 基准方案

在本节中,我们将以 Atari Pong 游戏环境为例,尝试加速其收敛速度。我们将以简单 DQN 模型作为基准方案,为评估改进效果,我们将采用两项关键指标:

  • 环境交互帧率 (FPS)。该指标表示训练过程中每秒从环境获取的帧数,反映系统与环境的通信效率。强化学习论文中通常会统计智能体在训练期间观测的总帧数,通常值为 2500 万至 5000 万帧。若 FPS=200,则做好 2500 万帧训练需约 2.89 天。计算时需要注意:学术论文通常统计原始环境帧数,但如果使用了帧跳过技术(通常跳 4 帧),则帧数需要除以该系数。在本节中,我们的测量采用智能体与环境交互次数为基准,故"原始环境 FPS"应为当前值的 4
  • 游戏通关耗时。当最近 100 回合的平滑奖励达到 18 分(Pong游戏最高分为 21 分)时停止训练。该阈值可调高,但 18 分通常表明智能体已基本掌握游戏策略,后续只需持续训练即可精益求精。之于是监测实际耗时,是因为仅凭 FPS 不足以全面反映训练加速效果
由于代码优化可能大幅提升 FPS 却影响收敛性能,且训练过程本身具有随机性(即使为 PyTorchGymNumPy 设置固定随机种子,后续步骤采用的并行化技术仍会引入不可避免的随机因素),单一 FPS 值不能可靠评估改进效果。最佳实践是多次运行基准测试取平均值——任何单次测试结果都不足以作为决策依据。 为了降低随机性的影响,本节所有数据均来自 5 次重复实验的平均值,所有基准测试均在相同硬件配置下进行。 首个基准测试是基础版本( baseline.py,完整代码参考 DQN 一节),训练过程中,代码会向 TensorBoard 记录以下指标:
  • reward:单回合未折扣的原始奖励值,横轴为回合序号
  • avg_reward:采用
α

=

0.98

α=0.98

α=0.98 滑动平均平滑处理的奖励值

  • steps:每回合持续的步骤数。通常,在开始时,智能体学习得很慢,每回合约 1000 步;随着策略改进,步数会增至 3000-4000 步并伴随奖励提升;但当智能体完全掌握游戏后,由于折扣因子

γ

γ

γ 促使策略追求速胜,步数会回落到 2000 步左右

  • loss:每 100 次训练迭代采样的损失值,通常维持在

2

×

10

3

2×10^{-3}

2×10−3 至

1

×

10

2

1×10^{-2}

1×10−2 区间。当智能体发现新行为导致Q值与实际奖励出现偏差时,损失值会出现短暂跃升

  • avg_loss:经平滑处理的损失值
  • epsilon:当前

ε

ε

ε 值——随机动作选择概率

  • avg_fps:经滑动平均处理的智能体与环境交互速度(观测次数/秒)

下图展示了 5 次基准测试的平均结果。

3. PyTorch 中的计算图

我们的首个优化算法并非直接加速基准模型,而是揭示一个常见却容易被忽视的性能陷阱。PyTorch 的梯度计算机制会对张量操作构建计算图,当调用最终损失的 backward() 方法时,自动计算模型所有参数的梯度。
这一机制在传统监督学习中运行良好,但强化学习代码通常更为复杂。当前训练的 RL 模型同时承担着为智能体提供环境行动决策的任务,而目标网络 (target network)进一步增加了复杂度。DQN 中的神经网络通常涉及三种场景:

1. 计算预测Q值:用于获取与贝尔曼方程估算的参考Q值之间的损失
2. 目标网络应用:获取下一状态的Q值以计算贝尔曼近似值
3. 动作决策:智能体选择待执行动作

其中仅第一种场景需要梯度计算。在第二种场景中,我们通过对目标网络返回张量显式调用 detach() 来阻断梯度——这一操作至关重要,它能防止梯度"从意外方向"流入模型,缺失该操作可能导致 DQN 完全无法收敛。第三种场景则通过将网络输出转为 NumPy 数组来阻断梯度。
虽然使用阻断梯度的代码能正常运行,但我们忽略了一个细节:以上三种场景都会创建计算图。尽管未调用 backward()PyTorch 仍会消耗资源(速度与内存)构建这些无用计算图。此时最佳解决方案是使用装饰器 torch.no_grad()
Python 装饰器是一个功能强大的特性,我们通过以下示例简要说明其用法,首先定义两个函数:

import torch
@torch.no_grad
def fun_a(t):
    return t*2

def fun_b(t):
    return t*2

这两个函数的功能相同——都将输入参数翻倍,但第一个函数使用了 torch.no_grad() 装饰器(该装饰器会临时禁用函数内所有张量的梯度计算),第二个则是普通函数。如示例所示,虽然输入张量t需要梯度 (requires_grad=True),但经过装饰的函数 fun_a 的输出结果不会携带梯度:

t = torch.ones(3, requires_grad=True)
t
# tensor([1., 1., 1.], requires_grad=True)
a = fun_a(t)
b = fun_b(t)
b
# tensor([2., 2., 2.], grad_fn=)
a
# tensor([2., 2., 2.])

但梯度禁用效果仅作用于被装饰函数内部:

a*t
tensor([2., 2., 2.], grad_fn=)

torch.no_grad() 函数也能够作为上下文管理器使用,用于在某段代码中停止梯度计算:

with torch.no_grad():
    c = t*2
c
# c = t*2

该功能提供了一种非常便捷的方式,能够明确指定代码中需要完全排除在梯度计算体系之外的部分。
为了验证不必要计算图带来的性能影响,我们在 slow_grads.py 中修改基准代码,其功能完全相同,但在智能体和损失计算部分移除了 torch.no_grad()。下图展示了这一改动带来的影响:

如图所示,性能损耗并不大(约 10 FPS),但对于结构更复杂的大型网络,差异可能会更显著。

4. 多环境并行

加速深度学习的常规思路是增大批大小 (batch size),这在深度强化学习中同样适用,但需谨慎处理。监督学习中"批大小越大越好"的原则通常成立——只需根据 GPU 内存尽可能增大批次,借助 GPU 强大的并行计算能力,更大批次意味着单位时间内能处理更多样本。但强化学习的情况略有不同。训练过程中同时发生两个关键过程:

1. 网络通过当前数据优化预测能力
2. 智能体持续探索环境

随着智能体不断探索环境并学习其行为结果,训练数据会动态变化。以射击游戏为例:初期智能体可能只会随机移动,被怪物击中,在训练缓冲区中积累大量"处处受敌"的负面经验;但当它偶然发现可用武器后,训练数据将发生质的飞跃。强化学习的收敛性依赖于训练与探索之间的微妙平衡——若仅增大批次而不调整其他参数,极易导致对当前数据的过拟合(例如,在射击游戏中,智能体可能形成"游戏尽早结束是减轻痛苦的唯一方式"的错误认知,永远无法发现随身武器)。
于是,在 n_envs.py 的中,我们的智能体使用多个相同环境的副本来收集训练数据。在每次训练迭代中,智能体会从所有并行环境中收集样本填充经验回放池,并按比例扩大批处理数据的规模。这种方式还能略微加快推理速度,因为我们能够通过神经网络的一次前向传播,同时为所有 N 个环境生成要执行的动作决策。
具体实现仅需对代码进行少量修改:

  • ExperienceSource 实例传入 NGym 环境副本
  • DQNAgent 已内置批量推理优化
为此我们修改了部分代码。生成批数据的函数会在每次训练迭代时执行多步操作(步数等于环境总数):

def batch_generator(buffer: lib.experience.ExperienceReplayBuffer,
                    initial: int, batch_size: int, steps: int):
    buffer.populate(initial)
    while True:
        buffer.populate(steps)
        yield buffer.sample(batch_size)

经验源接受环境数组而非单一环境:

envs = [
        lib.common.wrappers.wrap_dqn(gym.make(params.env_name))
        for _ in range(args.envs)
    ]
    params.batch_size *= args.envs

    exp_source = lib.experience.ExperienceSourceFirstLast(
        envs, agent, gamma=params.gamma)

其余修改仅为调整常量的微调,包含更新 FPS 追踪器和补偿 epsilon 衰减速度(随机步数比例)。由于环境数量成为需要调优的新超参数,我们针对 N=2...6 进行了多组实验。下图展示了平均动态:

从图中可以看出,新增一个训练环境可使 FPS 提升 47% (从 227 帧/秒增至 335 帧/秒),并加快约 10% 的收敛速度(从 52 分钟缩短至 48 分钟)。引入第三个环境时同样呈现增益效果( 398 帧/秒,进一步缩短至 36 分钟),但继续增加环境数量虽能进一步提高 FPS,却会对收敛速度产生负面影响。因此,超参数 N=3 大致可视为最优值,我们也可以自行调整实验。这也解释了为何我们不仅要监控原始FPS指标,还需关注智能体解决游戏的实际速度。

5. 在不同进程中进行游戏和训练

训练流程包含以下循环步骤:

1. 由当前网络决策动作,并在我们的环境数组中执行这些动作。
2. 将观测数据存入经验回放缓冲区
3. 从经验回放缓冲区中随机采样训练批次
4. 基于该批次进行训练

前两步旨在用环境样本(含观测值、动作、奖励及下一观测值)填充经验回放缓冲区,后两步则专注于网络训练。下图展示了上述步骤的运作机制,能更直观地呈现潜在的并行处理能力:左侧训练流程中,实线表示数据与代码流向,虚线代表神经网络在训练与推理中的应用,整个流程涉及训练环境、回放池和神经网络三大组件的协同运作。

如图所示,前两个步骤仅通过经验回放池和神经网络与底层进行通信。这种架构使得将这两个部分分离到不同的并行进程成为可能。下图展示了该方案的框架:

Pong 环境中,这种分离看似增加了不必要的代码复杂度,但在某些场景下却至关重要。假设需要处理一个计算缓慢的环境——每个步骤都需要数秒运算时间(例如"学习奔跑"、"假肢挑战"和"学习移动"项目,都涉及运行迟缓的神经肌肉模拟器)。这类情况下,务必将经验收集与训练过程分离。此时可通过部署多个并发环境,共同向中央训练流程输送经验数据。
要实现从串行代码到并行代码的转换,需要进行若干改造,修改部分在文件 parallel.py 中进行实现。

(1) 首先,使用 torch.multiprocessing 模块作为 Python 标准 multiprocessing 模块的替代方案:

import torch.multiprocessing as mp

NAME = "parallel"

@dataclass
class EpisodeEnded:
    reward: float
    steps: int
    epsilon: float

def play_func(params: common.Hyperparams, net: dqn_model.DQN,
              dev_name: str, exp_queue: mp.Queue):
    env = gym.make(params.env_name)
    env = lib.common.wrappers.wrap_dqn(env)
    device = torch.device(dev_name)

    selector = lib.actions.EpsilonGreedyActionSelector(epsilon=params.epsilon_start)
    epsilon_tracker = common.EpsilonTracker(selector, params)
    agent = lib.agent.DQNAgent(net, selector, device=device)
    exp_source = lib.experience.ExperienceSourceFirstLast(
        env, agent, gamma=params.gamma)

    for frame_idx, exp in enumerate(exp_source):
        epsilon_tracker.frame(frame_idx//2)
        exp_queue.put(exp)
        for reward, steps in exp_source.pop_rewards_steps():
            ee = EpisodeEnded(reward=reward, steps=steps, epsilon=selector.epsilon)
            exp_queue.put(ee)

标准库中的多进程模块提供了多种基础组件(如分布式队列 mp.Queue、子进程 mp.Process 等)来管理跨进程代码执行。PyTorch 在此基础之上进行了封装升级,实现了无需复制的跨进程张量共享机制——CPU 张量通过共享内存实现,GPU 张量则通过 CUDA 引用传递。这种共享设计有效消除了单机通信时的性能瓶颈(但真正的分布式通信仍需自行处理数据序列化)。
核心功能函数 play_func 作为"游戏进程"的实现,将在主进程启动的子进程中运行。其核心职责包含:从环境中采集经验数据并存入共享队列,将包含回合终止信息的结构化数据(含奖励值、步数等指标)同步推送至队列,确保训练进程能实时获取环境反馈。

(2) batch_generator 函数由 BatchGenerator 类替代实现:

class BatchGenerator:
    def __init__(self, buffer_size: int, exp_queue: mp.Queue,
                 fps_handler: lib_ignite.EpisodeFPSHandler,
                 initial: int, batch_size: int):
        self.buffer = lib.experience.ExperienceReplayBuffer(
            experience_source=None, buffer_size=buffer_size)
        self.exp_queue = exp_queue
        self.fps_handler = fps_handler
        self.initial = initial
        self.batch_size = batch_size
        self._rewards_steps = []
        self.epsilon = None

    def pop_rewards_steps(self) -> tt.List[tt.Tuple[float, int]]:
        res = list(self._rewards_steps)
        self._rewards_steps.clear()
        return res

    def __iter__(self):
        while True:
            while self.exp_queue.qsize() > 0:
                exp = self.exp_queue.get()
                if isinstance(exp, EpisodeEnded):
                    self._rewards_steps.append((exp.reward, exp.steps))
                    self.epsilon = exp.epsilon
                else:
                    self.buffer._add(exp)
                    self.fps_handler.step()
            if len(self.buffer) < self.initial:
                continue
            yield self.buffer.sample(self.batch_size)

BatchGenerator 类不仅提供批量数据的迭代功能,还通过 pop_reward_steps() 方法模拟了 ExperienceSource 接口。其核心逻辑如下:持续消耗由"游戏进程"填充的队列,若接收到 EpisodeEnded 对象,则记录当前探索率 (epsilon) 及游戏步数;若为普通经验数据则存入回放缓冲池。该机制会实时清空队列可用数据,随后从缓冲池采样训练批次并生成产出值。

(3) 在训练初始化阶段,需首先指定 torch.multiprocessing 的启动方式(虽然提供多种选项,但 spawn 模式具备最佳兼容性):

if __name__ == "__main__":
    # get rid of missing metrics warning
    warnings.simplefilter("ignore", category=UserWarning)
    mp.set_start_method('spawn')

(4) 接着创建通信队列,并以神经网络模型、超参数及经验传输队列作为参数传入,将 play_func 作为独立进程启动。

exp_queue = mp.Queue(maxsize=2)
    proc_args = (params, net, args.dev, exp_queue)
    play_proc = mp.Process(target=play_func, args=proc_args)
    play_proc.start()

使用 BatchGenerator 实例作为 IgniteEndOfEpisodeHandler (需要调用 pop_rewards_steps() 方法)的数据源。基准测试结果如下所示:

如图所示,在帧率方面我们实现了 27% 的提升:并行版本达到 290 FPS,而基准版本仅为 228 FPS。解决环境问题的平均耗时降低了 41%
虽然帧率提升幅度不及上一小节的最佳结果(使用 3 个游戏环境时帧率接近 400 FPS),但当前并行版本的收敛速度更快。

6. 优化包装器

最后,我们将调优化环境包装器。这些包装器往往容易被忽视——通常编写做好后就直接套用在环境上,或是从其他代码库直接移植过来便不再调整。但我们务必意识到,这些包装器对算法运行速度和收敛性能具有重要影响。以经典的 DeepMind 风格 Atari 游戏包装器为例,其典型结构包含:

  • NoopResetEnv:在游戏重置时执行随机次数的空操作。某些 Atari 游戏需要此操作来消除异常的初始观测值
  • MaxAndSkipEnv:对 N 次观测(默认 4 次)取最大值作为单步返回的观测值。这解决了部分Atari游戏画面"闪烁"问题,开发者通过奇偶帧绘制不同画面来突破平台性能限制
  • EpisodicLifeEnv:将角色一条生命值耗尽视为一回合结束,通过缩短回合长度(从游戏逻辑规定的多条生命变为单条生命),可显著提升收敛速度。该封装仅适用于部分 Atari 2600 学习环境支持的游戏
  • FireResetEnv:在游戏重置时自动执行 FIRE 动作。某些游戏需要此操作才能正式开始,否则环境会退化为部分可观测马尔可夫决策过程 (Partially Observable Markov Decision Process, POMDP),导致无法收敛
  • WarpFrame:也称 ProcessFrame84,将图像转换为灰度并将其调整为 84×84 像素
  • ClipRewardEnv:将奖励值裁剪到 [-1,1] 区间,统一不同 Atari 游戏间差异巨大的得分体系。例如 Pong 得分范围在 -2121,而 River Raid 得分可能从 0 到无穷大
  • FrameStack:将 N 次连续观测堆叠成张量(默认 4 次),某些游戏需要满足马尔可夫性质——像是在 Pong 中,单帧画面无法判断球的运动方向
这些包装器代码经过众多开发者的深度优化,存在多个版本,本节使用 Stable Baselines3 (OpenAI Baselines 项目的分支)。 但我们不应过渡依赖此代码,因为具体应用场景可能存在特殊需求。例如,如果只想加速某个特定 Atari 游戏,可能根本不需要 NoopResetEnvMaxAndSkipEnv。另一个可调参数是 FrameStack 封装器的帧堆叠数量——虽然 DeepMind 等研究者针对 50 多款 Atari 2600 游戏采用默认值 4 帧,但在特定场景中,2 帧历史数据可能既能满足需求又能减少神经网络计算量,从而提升性能。 最后,图像调整可能是包装器的瓶颈,因此可能需要优化包装器所使用的库,例如重新构建或用更快的版本替换它们,我们可以尝试不同的缩放方法和库。 针对 Pong 游戏,我们将实施以下优化:
  • 禁用 NoopResetEnv
  • MaxAndSkipEnv 替换为简化版,只跳过四帧而不进行最大池化
  • FrameStack 帧数缩减至 2
为了检查我们调整的组合效果,我们将在之前两个部分中进行的修改(多个环境和并行执行游戏和训练)上添加这些更改。由于这些更改不复杂,我们可以快速讨论它们,而不需要实际代码(完整代码可以在以下文件中找到 wrappers_n_env.py, wrappers_parallel.pyatari_wrappers.py):
  • atari_wrappers.py 包含了 wrap_dqn 函数以及从 Stable Baselines3 移植的 AtariWrapper
  • AtariWrapper 中,原本的 MaxAndSkipEnv 类被替换为简化版本,移除了帧间最大池化操作
  • wrappers_n_env.pywrappers_parallel.py 只是 n_env.pyparallel.py 的副本,仅对环境创建逻辑进行了微调
我们还尝试将 FrameStack 中保留的帧数减少到仅 1 帧(可以通过命令行参数 --stack 1 实现),但这种情况下仍然能够解决游戏,但所需的训练局数大幅增加,并且训练过程变得不稳定(约 8 次训练中有 3 次完全无法收敛)。这可能表明:在仅 1 帧的情况下,Pong 游戏并不完全符合 POMDP 的特性,因为智能体仍然能够依靠单帧观测学会如何获胜。但训练效率受到了严重影响。

小结

在本节中,我们探讨了多种通过纯工程化手段提升强化学习性能的方法,这与算法优化思路并不相同。这两种优化方法相辅相成,既需要掌握学界最新的技巧发现,也务必深谙实现细节的精妙之处。

系列链接

PyTorch强化学习实战(1)——强化学习(Reinforcement Learning,RL)详解
PyTorch强化学习实战(2)——强化学习环境库Gymnasium
PyTorch强化学习实战(3)——Gymnasium API扩展功能
PyTorch强化学习实战(4)——PyTorch基础
PyTorch强化学习实战(5)——PyTorch Ignite 事件驱动机制与实践
PyTorch强化学习实战(6)——交叉熵方法详解与实现
PyTorch强化学习实战(7)——表格学习与贝尔曼方程
PyTorch强化学习实战(8)——Q学习详解与实现
PyTorch强化学习实战(9)——深度Q学习
PyTorch强化学习实战(10)——强化学习高级组件
PyTorch强化学习实战(11)——N步DQN(N-step DQN)
PyTorch强化学习实战(12)——Double DQN(DDQN)
PyTorch强化学习实战(13)——噪声网络(NoisyNet-DQN)
PyTorch强化学习实战(14)——优先经验回放机制
PyTorch强化学习实战(15)——Dueling DQN
PyTorch强化学习实战(16)——Categorical DQN


这篇笔记就先到这里,后面用到新的思路或者发现有问题再补充。

评论 (0)

暂无评论