链载Ai

标题: FP8 低精度训练:Transformer Engine 简析 [打印本页]

作者: 链载Ai    时间: 2 小时前
标题: FP8 低精度训练:Transformer Engine 简析

一、背景介绍

业界广泛采用 FP16、BF16 混合精度(AMP)进行模型训练。AMP 能在下游任务不掉点的前提下提升训练效率、减少显存等资源占用,如今也常用于大模型预训练、微调等任务。

NVIDIA GPU 自 Hopper 架构起支持 FP8 精度的 Tensor Core 计算,相比于 FP16/BF16 精度,FP8 具有如下优势:



ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;font-size: medium;letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">

ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;font-size: medium;letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">我们仍然先从几个问题出发~

ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;margin-top: calc(1.90909em);margin-bottom: calc(1.27273em);clear: left;color: rgb(25, 27, 31);letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">什么是 FP8?

ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;font-size: medium;letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">相比于 16bit 精度,FP8 使用了更少的指数 bit 位和尾数 bit 位:

ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;font-size: medium;letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">

ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;font-size: medium;letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">在 NV、Arm、Intel 公布的 FP8 白皮书中[arXiv:2209.05433] 介绍了 FP8 数据的两种精度:E4M3和E5M2。两种数据格式的具体二进制表示如下表:

ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;font-size: medium;letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">

ingFang SC", "Microsoft YaHei", "Source Han Sans SC", "Noto Sans CJK SC", "WenQuanYi Micro Hei", sans-serif;font-size: medium;letter-spacing: normal;text-align: start;white-space: normal;background-color: rgb(255, 255, 255);">FP8 E4M3的表示范围为[-448, 448],E5M2为[-57334, 57334],根据其数据表示范围和精度需求,一般而言,E4M3 格式更适合 weight、activation 数据,E5M2 格式更适合 grad 数据。

为什么是 FP8,不是其他精度(比如 int8)?

FP8 精度训练的效果如何?在下游任务的表现如何?

FP8 在绝大多数训练任务下都能有 FP16 相当的精度,在少部分下游任务(如数学运算)存在一定差距。

各种 CV 模型在 FP8 精度下训练的分类精度【NV测试结果】:

NLP 预训练任务【NV测试结果】:

LLM Benchmark【NV测试结果】:

SFT 微调效果:

FP8 有哪些应用场景/案例?

二、FP16/BF16 AMP

回顾 Pytorch AMP 的实现原理:

1.计算流程

通常的 FP16 AMP 计算流程为:

FP16 支持算子:https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float16(https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float16)

Pytorch 使用 AMP 的样例代码如下(注意此处精度为 BF16):

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
outputs = model(inputs)
loss = loss_func(outputs, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()

2.显存分布

FP16 AMP 训练过程中,显存包含如下数据:

3.Global Loss Scaling

由于 FP16 能够表示的数值范围更小,因此对于 FP16 精度的 AMP,需要进行 loss scaling。

scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast(dtype=torch.float16):
outputs = model(inputs)
loss = loss_func(outputs, targets)

scaler.scaled(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
关于数值范围:FP32: 1-8-23 / BF16: 1-8-7 / FP16: 1-5-10
BF16 的数值范围和 FP32 一致,均有 8 个指数位,不需要 scaler 调整数值范围。
但 FP16 仅有 5 个指数位,当原数值过大或过小时,转换到 FP16 就可能出现 overflow/underflow,对训练造成影响。

实际上,我们会维护一个全局的scale 值,并采用 Dynamic Loss Scaling 动态调整这个全局的 scale 值。即每当梯度溢出时候减少损失缩放规模,并且间歇性地尝试增加损失规模,从而实现在不引起溢出的情况下使用最高损失缩放因子,更好地恢复精度。

三、FP8 技术分析

1.宏观实现框架

FP16 所采用的 Loss Scaling 与量化的思想非常相似,它可以看成是对全局的梯度数据做离线量化(PTQ)。

FP8 的数据范围有更大的限制,单一的全局 Scale 值无法满足众多数据分布的相对精确表示,因此我们可以仿照量化的思路,将量化的基本单位缩小至 tensor(更细致的量化,如 Block-wise quantization,理论上可以用于更低精度的训练上)。

FP8 对每一个 tensor(无论是输入数据、前向计算结果、反向计算结果)都计算一个Per-tensor Scaling Factor,以此做更加细致的量化,充分利用 FP8 为数不多的格点数。

具体而言,每一次前向的 GEMM 计算需要对 3 个 tensor 记录 scale 值:input,weight,output;而相对应的反向计算需要记录 2 个 tensor 的 scale 值:grad_outputgrad_input。在 TE 的 Hybrid 模式下,前向 tensor 数据格式为 E4M3,反向 tensor 数据格式为 E5M2。两种 FP8 精度的量化方式基本相同,均采用对称线性量化,我们只需要关心 scale 值。可以参考之前的内容:

然而,FP8 训练最关键的问题是,如何在训练过程中高效地寻找到这个 scale 值?

NV 在 TE 文档中给出了两种方案:



如下图所示,如果我们知道了 scale 值,那么计算的公式和伪代码就比较直接了:

TE 框架采用 Delayed scaling 方案,即对每个 GEMM 算子用到的 tensor 记录一个 amax history 数组,当我们需要 scale 值时,就从这个数组中取出最近一段时间窗口内 amax 的最大值,以此近似现在这个 tensor 的 amax,并默认用以下方式计算 scale 值:

FP8_MAX = maximum_representable_value(fp8_format)
new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)

用户可以自定义 Delayed scaling 的策略(Recipe),例如:

2.TE 及各类框架集成方法

总的来说,任何结合 FP8 能力的框架都只需要做两件事:

  1. 使用 TE 模块搭建 model,因为计算要用到 TE 提供的 FP8 算子;

  2. 用 fp8_autocast 装饰前向计算过程。

在实际场景下,FP8 训练通常需要结合 BF16 混合精度训练。

1)TE 官方案例:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = model(inp)

loss = out.sum()
loss.backward()

2)Accelerate:支持 DDP 和 FSDP 的 FP8 训练

# We prepare fp8 after, allowing for bf16 autocast to happen first
if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
if not has_transformer_engine_layers(model):
with torch.no_grad():
convert_model(model)
model._converted_to_transformer_engine = True

kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
# If we are in DDP or FSDP, we delay `autocast` until after FSDP/DDP has been initialized
# to make use of the process group
if not self.delayed_fp8_autocast:
model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward)

3)Megatron Core:支持 Tensor、Sequence、Pipeline 并行与 FP8 训练结合

# define TE model
use_te = args.transformer_impl == "transformer_engine"
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec()
model = GPTModel(transformer_layer_spec=transformer_layer_spec)

# set autocast context
class TransformerBlock(MegatronModule):
def forward():
if self.config.fp8:
import transformer_engine# To keep out TE dependency when not training in fp8

if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")

fp8_recipe = TEDelayedScaling(
config=self.config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()

with fp8_context:
# Forward pass.

3.FP8 框架 TE 计算流程

首先分析 TE 框架入口fp8_autocast的源代码:

@contextmanager
def fp8_autocast(
enabled: bool = True,
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
try:
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
calibrating=calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=_graph)
yield
finally:
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)

FP8GlobalStateManager是一个单例,它保存了全局的 fp8 state 和每个 TE Module 的 fp8 scale/amax 信息。

fp8_autocast的主要工作是:

TE 模块都继承于TransformerEngineBaseModule这个基类。每个实例均有一个fp8_meta字典,这个字典包含了 fp8 的关键信息,里面记录的内容有:

经过代码分析,TE 框架的 FP8 计算流程大致如下:

4.Tensor Core 如何进行 FP8 训练

FP8 精度计算仅能运行在 Tensor Core 上。Tensor Core 的基本运算单元为 D = A*B + C,其中A、B、C、D 均为矩阵。每个 Tensor Core 能在一个时钟周期内完成 4*4 的 mma 运算,即一次矩阵乘法和一次矩阵加法。Tensor Core wmma::mma_syncAPI 的最小数据单元是 16*16 的矩阵,因此 TE 框架要求输入数据的各维度必须是 16 的倍数。

在 FP8 计算中,输入的两个矩阵可以是 FP8 两种精度的任意组合,并且 FP8 的 FLOPS 是 16bit 的两倍。两个 FP8 矩阵在完成一次 Tensor Core 运算后会输出高精度结果(FP16/FP32),因此这里存在着 FP8->FP16/FP32 以及 FP16/FP32->FP8 的精度转化过程。

四、总结与展望

FP8 训练的局限性

FP8 及更低精度训练的前景

FP8 训练在大模型场景下已具有明确的应用前景,目前也具有工业界的应用案例,因此它有望成为大模型高效训练的配置之一。

在硬件端,NV 最新的 BlackWell 架构开始支持 FP6、FP4 等更低精度的 Tensor Core 运算,并可能采用 Block-wise 的量化方案。而 Deepspeed 也推出了不依赖于硬件计算条件的 FP6 运算:参考链接(https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)

在低精度运算成为常规方案的今天,在保证训练精度不掉点,并采用低精度训练的性能提升幅度,可能还远未到达极限。

参考资料

FP8 相关论文:

TE 代码与文档:

NV 技术博客:






欢迎光临 链载Ai (https://www.lianzai.com/) Powered by Discuz! X3.5