TPU推荐系统训练全链路优化:输入管道与嵌入表性能提升实践
1. 项目概述与核心挑战在工业级深度学习推荐系统的构建中我们面临着一个核心的“剪刀差”困境一方面模型需要处理海量的稀疏特征这些特征通常被编码为规模极其庞大的嵌入表其参数量动辄达到数十亿甚至数百亿级别另一方面我们依赖像TPU这样的专用硬件加速器来获得极致的计算吞吐量但这些硬件最初是为密集、规整的矩阵运算如Transformer、MLP而设计的。如何让这两者高效协同是决定整个训练系统成本与效率的关键。我过去几年深度参与过多个超大规模推荐系统的TPU训练优化项目从输入数据准备到梯度回传几乎踩遍了每一个环节的坑。今天我想系统性地拆解一下如何针对TPU架构对推荐系统的训练流程进行全链路性能优化特别是聚焦于最棘手的嵌入表操作。简单来说优化的目标就是让昂贵的TPU芯片时刻保持“忙碌”避免它们因为等待数据输入管道瓶颈或因为低效的内存访问嵌入表查找瓶颈而空闲。这不仅仅是调几个参数那么简单它涉及到从数据流架构、计算图调度到底层硬件资源分配的一整套系统工程。本文将围绕两个最关键的子系统展开一是确保数据供给不卡脖子的输入管道二是榨干TPU SparseCore硬件潜力的嵌入表操作优化。我们会看到通过一系列组合拳包括共享输入生成、动态水平扩展、混合分区策略以及计算流水线化我们最终在真实的广告推荐模型上实现了平均超过一倍的性能提升同时显著降低了对外部CPU/内存资源的依赖。2. 输入管道优化告别TPU“饥饿”在分布式训练中TPU阵列的计算能力非常强大但如果喂给它们数据的速度跟不上其消耗速度那么这些昂贵的芯片就会处于闲置状态造成巨大的资源浪费。输入管道的优化首要目标就是解决TPU的“饥饿”问题。2.1 从“各自为政”到“共享厨房”共享输入生成服务在传统的模式即本地输入生成中每个独立的训练任务都需要运行自己专属的输入数据预处理流水线。这意味着即使有十个模型都在处理同一份用户点击日志提取相似的特征如用户历史行为序列、商品属性它们也会各自重复地进行完全一样的特征解析、转换和归一化操作。这相当于每个厨师都在自己的小厨房里从洗菜、切菜开始准备完全相同的菜品造成了CPU和内存资源的极大冗余。我们引入的共享输入生成服务其核心思想类似于一个中央厨房。这个“中央厨房”独立于任何具体的训练任务它负责执行最耗时的公共特征转换计算。工作流程如下特征图定义与子图识别所有模型将其特征处理逻辑一个计算图注册到SIG。SIG会分析这些图识别出其中计算代价高昂且被多个模型共享的子图。例如从原始日志中解析出“用户过去7天点击某类目的次数”这个特征可能涉及复杂的窗口聚合和连接操作是典型的候选。计算与缓存SIG服务内的专用工作节点会执行这些公共子图的计算并将结果即转换后的特征张量持久化到高速存储中。这个过程被称为“物化”。按需分发当各个模型的输入读取器需要数据时它们不再重复执行完整的特征转换而是向SIG请求对应的物化结果。SIG根据请求的批次和特征键快速读取并返回缓存的数据。注意SIG并非缓存所有原始数据而是缓存中间特征转换结果。原始数据可能非常庞大且变化频繁而特征转换结果相对稳定且可复用。在我们的实践中SIG的缓存命中率超过95%平均每个物化的特征子图被超过22个模型复用峰值时可达400个以上。带来的收益是立竿见影的如图6所示与LIG相比SIG将输入读取器的资源成本降低了4.3倍到7.5倍。虽然TPU的成本在总成本中占主导地位但SIG依然带来了12%到27%的总训练成本下降几何平均为18%。更重要的是它解决了资源争用问题在LIG模式下为每个任务配置足够的CPU/内存资源非常困难经常导致TPU因等待数据而利用率低下。SIG通过资源共享确保了TPU阵列能够持续获得数据供给。2.2 动态水平扩展应对波动的数据需求输入读取器的资源需求并非一成不变。它主要受两个因素影响训练阶段在初始训练阶段模型需要快速处理历史积累的巨量数据此时使用SIG输入读取负载较低。而在追新训练阶段模型需要近乎实时地学习最新产生的数据由于数据新鲜度要求高无法使用SIG的缓存必须采用LIG模式导致CPU需求激增。模型复杂度不同模型的特征数量、转换逻辑复杂度差异巨大对输入读取器的压力也不同。表I的数据清晰地说明了问题对于50%的训练流水线其输入读取所需的CPU资源尚可由TPU主机本身满足比值为0.7。但对于90分位的任务其需求是单个TPU主机资源的3.5倍对于99分位的任务甚至高达16倍。这意味着仅靠TPU主机附带的CPU资源是远远不够的。因此我们实现了水平可扩展的输入读取器服务。该服务可以独立于TPU Pod进行部署和伸缩。其核心是一个控制器它持续监控两个指标数据队列深度TPU端等待处理的数据批次队列是否即将排空输入读取器利用率当前输入读取器节点的CPU/内存使用率是否持续过高基于这些指标控制器可以动态地增加或减少输入读取器实例的数量。例如当系统进入追新训练阶段控制器会自动扩容增加输入读取器以应对LIG模式下的高负载当训练进入稳定阶段或切换回使用SIG的初始训练时则可以缩容以节省资源。实操心得动态伸缩的粒度不宜过细。我们通常以“任务组”为单位进行伸缩并设置一个冷却期避免因瞬时波动导致频繁的启停操作这反而会引入开销和不稳定。同时需要为输入读取器配置足够的网络带宽和低延迟存储访问防止其成为新的瓶颈。3. 嵌入表操作优化征服TPU上的稀疏计算对于推荐模型嵌入表查找和更新是训练过程中最核心、也最耗时的稀疏操作。TPUv4通过集成专用的SparseCore硬件来高效处理这些操作但如何用好SC并与负责密集计算的TensorCore协同工作是性能优化的重中之重。3.1 理解计算流程与瓶颈一个典型的推荐模型训练步骤中嵌入层操作流程如下前向传播输入读取器提供一批训练样本每个样本包含多个特征ID。主机CPU首先对这些ID进行去重得到一批唯一的特征值。嵌入查找去重后的ID被发送给SparseCore。SC根据这些ID从其管理的分布式嵌入表分区中并行地查找并收集对应的嵌入向量。归约求和由于一个特征ID可能在同一个批次的不同样本中出现多次即“多值”特征SC收集到的是每个唯一ID的向量。接着这些向量被传递给TensorCoreTC根据样本-特征的映射关系执行段求和操作为每个样本生成其对应的聚合后的嵌入向量作为后续MLP等密集层的输入。反向传播梯度从密集层反向传播到嵌入层。TC计算得到每个本特征对应的嵌入梯度这些梯度需要根据原始ID映射分散更新回SC中对应的嵌入表行。这个流程的瓶颈非常明显负载不均衡不同特征ID的出现频率热度差异巨大导致某些SC核心需要处理的热门ID远多于其他核心。串行执行在朴素的实现中前向传播必须等待所有SC嵌入查找完成后TC才能开始密集计算反向传播也必须等待TC梯度计算完成后SC才能开始更新。SC和TC交替空闲硬件利用率低。内存带宽压力嵌入表通常远超单个SC的HBM容量必须分区存放。低效的分区策略会导致跨芯片通信频繁挤占宝贵的ICI带宽。3.2 核心优化策略一混合分区单纯按行分区将不同的嵌入表行分布到不同的SC上是最直观的方法但面对高度倾斜的访问分布时效果很差。假设有4个SC两个嵌入表T1和T2每表4行。每行的平均访问次数分别为0.6, 0.3, 0.2, 0.1。如果采用行分区最热门的行0.6落在某个SC上最冷的行0.1落在另一个SC上那么负载不均衡因子高达(4 * 0.6) / (0.60.30.20.1) 2意味着最忙的SC工作量是平均值的2倍。我们引入了混合分区策略它结合了三种维度表级分区将不同的整个嵌入表放置到不同的SC集合上。适用于表之间大小和访问模式差异大的场景。行级分区将一个大表的行分散到多个SC上。这是最基础的负载分散方法。列级分区将一个嵌入向量的不同维度列切分到不同的SC上。这是提升性能的关键。继续上面的例子采用混合分区首先进行表分区T1放在{SC0, SC1}T2放在{SC2, SC3}。然后对每个表进行列分区将每个64维的向量从中间切开前32维放在一组SCSC0, SC2后32维放在另一组SCSC1, SC3。这样每个SC最终负责存储SC0: T1的前32列SC2: T2的前32列SC1: T1的后32列SC3: T2的后32列。计算负载被完美均摊负载不均衡因子降为1。列分区的额外好处它降低了内存访问的粒度。在行分区中即使只需要嵌入向量的一部分也必须读取整行例如64个浮点数。而在列分区中可以只读取所需的那部分列。这减少了对HBM带宽的占用虽然可能增加一次通信如果需要完整的向量但在带宽成为瓶颈的场景下收益显著。注意事项混合分区策略的寻址逻辑变得复杂。在查找时系统需要知道目标ID的行分布在哪个SC以及其列切分情况然后向多个SC发起并行的查找RPC请求最后在TC侧将部分向量拼接成完整的嵌入向量。这需要运行时系统有精密的元数据管理和路由机制。3.3 核心优化策略二反馈导向的分区混合分区解决了静态的负载不均衡问题但还有一个动态挑战多值特征。例如“用户最近点击的100个商品”这个特征其值的数量称为“价态”在运行时是变化的且依赖于数据分布编译时无法预知。价态高的特征其嵌入查找和归约计算量也大。反馈导向的分区的核心思想是利用运行时收集的剖析信息来指导分区决策。系统在训练过程中会持续采样并记录每个训练批次中各个特征出现的总次数。每个特征ID的唯一值数量即去重后的价态。各SC核心的实时负载和通信量。这些统计信息被汇总到一个“剖析数据库”中。在定期或触发式的分区决策时刻系统会利用这些数据将高频、高价态的特征所对应的嵌入表行或列进行更细粒度的拆分或迁移以平衡SC间的计算和通信负载。在我们的实验中FDP为某些模型带来了额外的19%到21%的性能提升。3.4 核心优化策略三TC/SC流水线执行这是提升系统吞吐量的“神来之笔”。观察图5在严格的串行执行中SC和TC如同两个必须交接棒的运动员大部分时间总有一个在等待。流水线执行打破了这一步的严格依赖。其原理是允许SC提前开始下一步Step N1的嵌入查找而TC仍在处理当前步Step N的密集层计算。从数学上看这相当于在反向传播更新嵌入时使用的是上一步Step N的梯度而非当前步Step N1的梯度即梯度延迟了一拍。为什么这可行且有效梯度延迟的容忍性在推荐模型这种超大规模、数据噪声丰富的场景下训练过程本身具有很强的随机性如大规模SGD。延迟一拍的梯度可以看作是在梯度中引入了一个微小的、有偏的噪声。大量实验表明这种延迟对模型的最终收敛质量和效果没有可观测的负面影响。硬件利用率大幅提升如图7所示流水线化后训练每一步的时间从TC_time SC_time缩短为max(TC_time, SC_time)。只要TC和SC的计算时间不是严重失衡就能获得接近线性的加速。这对于那些嵌入层计算和密集层计算耗时相近的模型尤其有效。实操心得启用流水线会加剧对共享资源如HBM和ICI的争用因为TC和SC同时在活跃地访问内存和通信。因此需要仔细监控这些资源的利用率并可能需要对混合分区策略进行微调以平衡计算负载和通信压力。通常这是一个迭代调优的过程。4. 系统级保障鲁棒的资源与容错管理优化性能的同时必须保证系统的稳定性和资源效率。在共享数据中心训练任务动辄运行数周甚至数月期间必然会遇到各种干扰。4.1 智能错误处理与训练挂起训练任务可能进入无法继续的状态。我们将其分为两类并采取不同策略永久性错误如模型配置错误、编译失败、内存溢出、数值溢出NaN。系统检测到此类错误后会主动放置一个训练挂起信号。这意味着任务会暂停并释放所有TPU和输入读取器资源等待工程师介入排查。这避免了宝贵的TPU资源被一个注定失败的任务长期占用。瞬时性停滞最常见的原因是输入数据尚未就绪例如SIG仍在物化当前训练所需的数据范围。训练管道能检测到这种“数据未准备好”的状态并同样触发训练挂起释放计算资源。但与永久错误不同系统的控制器会保持活跃持续轮询数据可用性。一旦数据准备就绪控制器会自动解除挂起任务无缝恢复。这种机制至关重要。如表II所示在实际系统中处于“挂起”状态的模型所需求的TPU芯片量是正在活跃训练芯片量的2.49倍。如果不能快速识别并挂起出错任务这些“僵尸”任务将严重浪费集群资源。4.2 优雅的抢占处理在共享环境中高优先级任务、软件滚动更新或硬件维护都可能要求当前任务提前终止。粗暴地杀死任务会导致当前训练周期epoch的进度全部丢失。我们实现了抢占通知协议。外部调度系统在决定要抢占某个任务时会提前发送一个“预抢占通知”。任务接收到通知后通知控制器任务主进程通知中央控制器即将关闭。广播与收尾控制器广播此消息给该任务的所有工作节点如输入读取器。所有节点尝试完成手头正在处理的工作单元。检查点保存训练任务迅速将当前进度模型参数、优化器状态等保存到持久化存储中。优雅退出保存完成后任务主动退出。等待重启控制器会阻塞下一个训练周期的开始直到该任务在新的资源上重启并重新加入训练管道。这确保了被抢占的epoch进度得以保留实现了“断点续训”。在我们的实践中约61%的被抢占epoch成功保存了进度并得以恢复极大地减少了因中断造成的计算浪费。5. 性能评估与效果分析我们在一套由128个TPUv4芯片组成的系统上选取了五个占生产负载超过50%的代表性推荐模型进行评估。这些模型的密集部分参数量在5000万到3亿之间并包含数百个大小、维度、访问模式各异的嵌入表。5.1 嵌入优化效果分解图7清晰地展示了各项优化技术的累积效果基线仅使用行分区且TC/SC串行执行。这是最朴素的实现。流水线启用TC/SC流水线执行。所有模型均获得提升但对于SC瓶颈显著的模型A和E提升幅度相对较小因为其SC耗时远大于TC流水线后整体耗时仍由SC决定。混合分区在流水线基础上启用行、列、表混合分区。通过更精细的负载均衡所有模型性能进一步上涨。特别是对于之前SC瓶颈的模型负载被均摊后SC耗时下降与TC的耗时更加匹配从而让流水线的效果得以充分发挥。反馈导向分区在前两者基础上加入基于运行时剖析的FDP。对于模型B和E由于存在价态变化剧烈的多值特征FDP带来了额外的19%和21%的性能飞跃。模型A因主要负载集中在少数几个表上即使不用FDP也能较好分布因此提升有限。模型C和D在混合分区后已变为TC瓶颈因此FDP对它们没有进一步帮助。最终这一系列嵌入优化技术为五个模型带来了58%到180%的性能提升几何平均提升高达116%。5.2 成本与效率的权衡优化不仅是追求速度更是追求性价比。SIG通过共享计算将输入读取器的成本降低了数倍虽然对以TPU成本为主的总成本影响比例看似不大平均降低18%但其战略意义在于它使得为庞大的TPU集群持续供给数据成为可能从而将TPU的利用率提升到了一个新的高度。没有SIGTPU将因为输入瓶颈而大量闲置实际训练成本会成倍增加。6. 经验总结与未来展望回顾整个优化历程有几点深刻的体会全链路视角至关重要不能只盯着TPU矩阵乘的峰值算力。输入管道、嵌入查找、通信、容错任何一个环节的短板都会成为整个系统的瓶颈。必须像对待核心算法一样对待这些系统工程问题。拥抱“不精确”以换取效率流水线执行带来的梯度延迟本质上是一种用极小的、经验证可忽略的精度代价换取巨大硬件利用率提升的权衡。在工业级系统中这类基于对问题深刻理解的、可控的近似优化往往是突破性能瓶颈的关键。数据驱动的动态调优无论是反馈导向的分区还是输入读取器的动态伸缩都依赖于对运行时数据的持续收集和分析。静态的、一刀切的配置无法应对生产环境中复杂多变的工作负载。未来的探索方向也基于当前实践的延伸更智能的混合存储目前嵌入表全部驻留在TPU HBM中。未来可以探索分层存储将访问频率极低的“冷”嵌入行或表卸载到主机内存甚至SSD上从而在有限的HBM容量下支持更大的模型。这其中的挑战在于如何高效、动态地识别和迁移“冷热”数据。SIG的演进一是进一步减少存储开销例如只物化计算图中最昂贵的部分子图而非全图。二是探索支持可变数据的SIG使得追新训练也能受益于缓存共享但这会引入数据一致性和生产风险共享的复杂问题。优化大规模推荐系统的TPU训练是一场持续在算法、系统、硬件交叉地带进行的工程探险。每一次性能的百分比提升背后都是对数据流、计算图、硬件资源更精细的雕刻与编排。希望这些从实战中总结出的思路和细节能为同行们提供一些有价值的参考。