|
前言
大模型时代对显存的要求越来越高,之前在BERT刚诞生时候写过一篇:GPU 显存不足怎么办?,新的这篇文章主要是重构之前的文章,来聊聊大模型时代显存不足时怎么办,没有看过的朋友直接看这篇即可。
训练时显存占用分析
训练模型时所占用的显存主要分为以下部分:模型权重参数,优化器状态,梯度,激活值。假定模型本身的大小为 A,且以 fp32 为精度计算。
模型权重参数
优化器状态与梯度
我们看到在 SGD 中,那么此时的显存占用只有梯度: 。- 以 Momentum-SGD 为例,其计算公式为:
我们看到在Momentum-SGD 中,不仅仅有梯度 ,还有动量 。

我们看到在 Adam 中,需要保存的包括:当前梯度 ,梯度加权平均 ,梯度平方的加权平均 。因此,假定模型大小为 A,训练中采用 FP32 精度进行优化,那么此时优化器状态和梯度占用的显存分别为:
- Momentum-SGD:优化器状态:4A,梯度:4A
而在实际的训练中往往采用混合精度训练,而在混合精度训练下的显存又有所区别。
激活值
激活值的显存占用与 token长度,per_gpu_batch_size,hidden_size 以及 transformer层数 正相关,并且占用显存也非常大,此处就不细写了,主要是技术很复杂,我也没算明白,哈哈哈哈。
训练时显存不足怎么办?
下面列出一些常见的节省显存的操作,优先级从高到低排列。
- 去掉compute_metrics:有些代码会在输出层后计算rouge分等,这个会输出一个batch_size*vocab_size*seq_len 的一个大向量,非常占显存。
- 采用bf16/fp16进行混合精度训练:现在大模型基本上都采用 bf16 来进行训练,但是如v100这些机器不支持,可以采用fp16进行训练。显存占用能够降低一倍。
- Flash attention:不仅能够降低显存,更能提高训练速度。
- 降低你的batch size:如上文所述,batch size 与模型每层的激活状态所占显存呈正相关,降低batch size 能够很大程度上降低这部分显存占用。
- 采用梯度累积:global batch size = batch size * 梯度累积,如果降低 batch size 后想保持你的 global batch size 不变,可以适当提高梯度累积值。
- 选择合适的上下文长度:如上文所述,上下文长度与激活状态所占显存呈正相关,因此可以通过适当降低上下文长度来降低显存占用。
- DeepSpeed Zero:显存占用从高到低为:Zero 1 > Zero 2 > Zero 2 + offload > zero 3 > zero 3 + offload,推荐最多试到 Zero2 + offload。
- 选择更小的基座模型:在满足需求的情况下,尽量选择更小的基座模型。
- Lora:能跑全参就别跑 Lora 或 Qlora,一方面是麻烦,另一方面的确是效果差点。
- Qlora:Qlora 的速度比lora慢,但所需显存更少,实在没资源可以试试。
- Megatron-LM:可以采用流水线并行和张量并行,使用比较麻烦,适合喜欢折腾的同学。
- Pai-Megatron-LM:Megatron-LM 的衍生,支持 Qwen 的sft和pt,坑比较多,爱折腾可以试试。
- 激活检查点:不推荐,非常耗时。在反向传播时重新计算深度神经网络的中间值。用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)。
最后
ok,本文到此就结束了,本文主要是对之前文章进行了细化,并补充了大模型时代下的几种显存不足时的方法。
参考
【1】https://zhuanlan.zhihu.com/p/31558973
|