本笔记参考了 AISystem开源课程。
混合精度训练 (mixed-precision training) 通过在模型训练过程中使用不同的数值精度来达到加速训练和减少内存消耗的目的。其核心思想在于:
- 低精度计算:前向传播和反向传播时,使用较低的精度 (e.g.,
float16)。 - 高精度参数:16 位浮点数的范围和精度有限。为了维持训练稳定性和模型性能,使用高精度来保存和更新模型参数 (e.g.
float32)。 - 损失缩放:将 loss 乘上一定倍率,使梯度相应变大,防止梯度下溢。
Note
本文以
float16作为低精度浮点类型。实践中,更推荐使用bfloat16,因为其表示范围更大。
具体而言,模型的权重始终以 float32 形式存储;这份高精度的权重被称作master weight。混合精度训练每一轮的执行过程为:
- 基于当前 master weight 获取一份
float16的权重副本。master weight 仍保留在内存中。 - 用
float16的权重和激活值执行前向传播,得到float16的 loss。 - 施加损失缩放,将 loss 放大 倍。
- 使用放缩后的 loss、权重和中间激活值 (均是
float16) 完成反向传播,得到float16梯度。此时的梯度值是常规梯度值的 倍(见下式),故不容易下溢。
- 将
float16梯度 转化为float32梯度,随后除以 做反缩放。float32可以表示更小的数,所以除以 后也不容易下溢。 - 用
float32梯度更新float32master weight。
换言之,只有 master weight 以及 master weight 的更新量采用 float32 表示;其余的临时权重、激活值等张量全部采用 float16 表示。
--- title: Mixed Precision Training displayMode: compact --- flowchart LR mw{{"Master-W *float32*"}} w{{"W *float16*"}} x{{"x *float16*"}} y{{"y *float16*"}} l{{"L *float16*"}} al{{"Scaled L *float16*"}} gx{{"Grad x *float16*"}} gy{{"Grad y *float16*"}} gw{{"Grad W *float16*"}} gw2{{"Unscaled Grad W *float32*"}} mw'{{"Updated Master-W _float32_"}} wu["Weight Update"] fwd["forward"] bwd["backward"] mw --> w mw --> wu wu --> mw' x --> fwd w --> fwd fwd --> y y -. forward .-> l l -- "scale up" --> al al -. backward .-> gy gy --> bwd x --> bwd w --> bwd bwd --> gx bwd --> gw gw --typecast, unscale--> gw2 gw2 --> wu