The Principles of FlashAttention, Finally Explained!

Today's LLMs are built on the Transformer architecture, whose core is self-attention; as the input sequence grows, both its time and memory complexity grow quadratically with sequence length.
To address the challenge of extending the Transformer's context length, researchers from Stanford University and the University at Buffalo (SUNY) proposed FlashAttention, a fast and memory-efficient attention algorithm that accelerates attention and reduces memory usage without any approximation.
The core idea of FlashAttention is to split the Q, K, and V inputs into blocks, carry out the attention computation for each block entirely in on-chip SRAM (shared memory/L1), and only write the accumulated results back to HBM, thereby reducing reads and writes to high-bandwidth memory (HBM).
In short, FlashAttention attacks the problem from the side of GPU memory access: by cutting the amount of HBM traffic, it achieves roughly a 2-4x speedup.
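To get a feel for why the quadratic growth matters, here is a rough, illustrative calculation (the sequence lengths are chosen for illustration and are not from the article): standard attention materializes the full N x N score matrix, whose size quickly dominates GPU memory traffic.

# Illustrative only: size of the full N x N attention score matrix (per head, fp16)
# that standard attention materializes and streams through HBM.
bytes_per_elem = 2  # fp16
for n in (1024, 8192, 32768):
    gib = n * n * bytes_per_elem / 2**30
    print(f"N = {n:6d}: score matrix ≈ {gib:.3f} GiB per head")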
The Core Principles of FlashAttention

01

From Online Softmax to FlashAttention

When computing attention, the dot products can be accumulated block by block, but softmax cannot be applied to each block independently, because every element needs the global normalizer; the computation therefore has to be redesigned.
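To make this concrete, here is a small illustrative check (my own example, not from the article): applying softmax to each block separately does not reproduce the global softmax, because each block uses its own normalizer.

import numpy as np

x = np.array([0, 7, 6, 12, 10], dtype=np.float32)

# global softmax over the whole vector
full = np.exp(x) / np.sum(np.exp(x))

# naive "blockwise" softmax: each block normalizes by its own sum
block1, block2 = x[:3], x[3:]
naive = np.concatenate([
    np.exp(block1) / np.sum(np.exp(block1)),
    np.exp(block2) / np.sum(np.exp(block2)),
])

print(np.allclose(full, naive))  # False: the denominators differ between blocks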

First, look at how softmax is computed, as shown in Equation 1 below:

Equation 1:

\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}

Computing it directly in float16, however, overflows:

import numpy as np

inputs = np.array([0, 7, 6, 12, 10], dtype=np.float16)
eInputs = np.exp(inputs)
result = eInputs / np.sum(eInputs)
print(eInputs)
print(np.sum(eInputs))
print(result)
Program output:
[1.000e+00 1.097e+03 4.035e+02       inf 2.203e+04]
inf
[ 0.  0.  0. nan  0.]
Since e^12 already exceeds the float16 maximum (about 65504), the exponentials overflow and the result contains inf and nan. To mitigate this, a trick called safe-softmax is normally used: subtract the maximum value from every element before exponentiating, as in Equation 2.
Equation 2:

\mathrm{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}}, \qquad m = \max_{j} x_j
import numpy as np

inputs = np.array([0, 7, 6, 12, 10], dtype=np.float16)
max_val = max(inputs)
emInputs = np.exp(inputs - max_val)
result1 = emInputs / np.sum(emInputs)
print(emInputs)
print(np.sum(emInputs))
print(result1)
print(sum(result1))
Program output:
[6.139e-06 6.737e-03 2.480e-03 1.000e+00 1.354e-01]
1.145
[5.364e-06 5.886e-03 2.167e-03 8.735e-01 1.183e-01]
0.9998794794082642
Note that safe-softmax gives exactly the same result as the ordinary softmax.
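The equivalence is a one-line algebraic identity: subtracting m only multiplies the numerator and the denominator by the same constant e^{-m}, which cancels.

\frac{e^{x_i - m}}{\sum_{j} e^{x_j - m}}
  = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_{j} e^{x_j}}
  = \frac{e^{x_i}}{\sum_{j} e^{x_j}}, \qquad m = \max_{j} x_j

The example below confirms this numerically.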
import numpy as np

inputs = np.array([1, 1, 3, 3, 3], dtype=np.float16)

# softmax
eInputs = np.exp(inputs)
result = eInputs / np.sum(eInputs)
print(result)

# safe-softmax
max_val = max(inputs)
emInputs = np.exp(inputs - max_val)
result1 = emInputs / np.sum(emInputs)
print(result1)
Program output:
[0.04138 0.04138 0.3057 0.3057 0.3057]
[0.04138 0.04138 0.3057 0.3057 0.3057]
As shown in Figure 1, safe-softmax can be written as a 3-pass procedure.
Figure 1: 3-pass safe softmax
Figure 2: 2-pass online softmax
Figure 3: the recursive form of d_i'
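Written out (a standard formulation of the quantities behind Figures 1-3; the exact notation in the original figures may differ), the 3-pass safe softmax and the 2-pass online-softmax recurrence are:

\text{3-pass: } m_N = \max_{i} x_i, \qquad
d_N = \sum_{i=1}^{N} e^{x_i - m_N}, \qquad
\mathrm{softmax}_i = \frac{e^{x_i - m_N}}{d_N}

\text{2-pass (online): } m_i = \max(m_{i-1}, x_i), \qquad
d_i' = d_{i-1}'\, e^{m_{i-1} - m_i} + e^{x_i - m_i}, \qquad
m_0 = -\infty,\ d_0' = 0,\ d_N' = d_N

The recurrence for d_i' lets the maximum and the denominator be maintained in a single sweep over the scores, which is exactly what the code below does.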
# online softmax, 2-pass
import torch

L = 8
inputs = torch.randn(L)
result = torch.zeros(L)

m = torch.tensor(float("-inf"))
d = 0
for i in range(L):
    m_new = torch.max(m, inputs[i])
    d = d * (m - m_new).exp() + (inputs[i] - m_new).exp()
    m = m_new

for i in range(L):
    result[i] = (inputs[i] - m).exp() / d

print('online softmax result:', result)
print(torch.sum(result))

# safe-softmax, 3 passes (for comparison)
max_value = torch.max(inputs)
eX = torch.exp(inputs - max_value)
result1 = eX / torch.sum(eX)
print('safe softmax result:', result1)
print(torch.sum(result1))
Program output:
online softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])
tensor(1.)
safe softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])
tensor(1.0000)

02

Flash Attention

Multi-pass self-attention is simply self-attention combined with the 2-pass online softmax (Figure 4). Can the whole thing be written as a single pass? The answer is yes; the recurrence is sketched after the figure and equation captions below.

Figure 4: multi-pass self-attention
Equations 3-5: derivation of the 1-pass form
Figure 5: 1-pass FlashAttention
Figure 6: 1-pass FlashAttention (tiling)
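To make the 1-pass idea concrete, here is a minimal single-query sketch (illustrative code of my own, not the article's implementation; the variable names are assumptions). It folds the value accumulation into the same loop that maintains the online-softmax statistics, which is the recurrence behind the 1-pass form in Figures 5 and 6:

import torch

torch.manual_seed(0)
N, d = 8, 4
q = torch.randn(d)
K = torch.randn(N, d)
V = torch.randn(N, d)

x = K @ q                          # attention scores for this single query

m = torch.tensor(float("-inf"))    # running max  m_i
denom = torch.tensor(0.0)          # running denominator  d_i'
o = torch.zeros(d)                 # running (already normalized) output  o_i'

for i in range(N):
    m_new = torch.maximum(m, x[i])
    denom_new = denom * (m - m_new).exp() + (x[i] - m_new).exp()
    # o_i' = (o_{i-1}' * d_{i-1}' * exp(m_{i-1} - m_i) + exp(x_i - m_i) * V[i]) / d_i'
    o = (o * denom * (m - m_new).exp() + (x[i] - m_new).exp() * V[i]) / denom_new
    m, denom = m_new, denom_new

ref = torch.softmax(x, dim=0) @ V  # standard attention for the same query
print(torch.allclose(o, ref, atol=1e-6))  # expected: True

FlashAttention applies the same recurrence per block of keys/values rather than per element, so each block only needs to be read from HBM once.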

03

Code Analysis

Pseudocode of FlashAttention V1 (the numbered step comments in the code example below follow it):

Pseudocode of FlashAttention V2: compared with V1, the outer loop runs over query blocks with the key/value blocks inside, and division by the running denominator is deferred until the inner loop finishes.

A concrete code example with actual data:
import time
import torch

torch.manual_seed(0)
NEG_INF = -1e10  # stand-in for -infinity
EPSILON = 1e-10

Q_LEN = 6
K_LEN = 6
Q_BLOCK_SIZE = 3
KV_BLOCK_SIZE = 3
P_DROP = 0.2  # dropout probability (the dropout step is omitted below)

Tr = Q_LEN // Q_BLOCK_SIZE    # number of Q blocks
Tc = K_LEN // KV_BLOCK_SIZE   # number of K/V blocks

Q = torch.randn(1, 1, Q_LEN, 4, requires_grad=True).to(device='cpu')
K = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')
V = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')

# step 4: split Q, K, V into blocks along the sequence dimension
Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)

print("----------------FlashAttentionV1------------------------")
O = torch.zeros_like(Q, requires_grad=True)
l = torch.zeros(Q.shape[:-1])[..., None]           # running softmax denominators
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF  # running row maxima

# step 5: split the output and the running statistics into matching blocks
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

# step 6
start_time1 = time.time()
for j in range(Tc):
    # step 7: load one K/V block
    Kj = K_BLOCKS[j]
    Vj = V_BLOCKS[j]
    # step 8
    for i in range(Tr):
        # step 9: load one Q block together with its running statistics
        Qi = Q_BLOCKS[i]
        Oi = O_BLOCKS[i]
        li = l_BLOCKS[i]
        mi = m_BLOCKS[i]
        # step 10: S_ij = Qi @ Kj^T
        S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)
        # step 11 (optional masking) is omitted here
        # step 12: block-local max, exponentials, and row sums
        m_block_ij, _ = torch.max(S_ij, dim=-1, keepdim=True)
        P_ij = torch.exp(S_ij - m_block_ij)
        l_block_ij = torch.sum(P_ij, dim=-1, keepdim=True) + EPSILON
        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
        # step 13: merge the block statistics into the running statistics
        mi_new = torch.maximum(m_block_ij, mi)
        li_new = torch.exp(mi - mi_new) * li + \
                 torch.exp(m_block_ij - mi_new) * l_block_ij
        # step 14 (optional dropout with p=P_DROP) is omitted here
        # step 15: rescale the old output and add this block's contribution
        O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi \
                      + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
        # step 16: write the updated statistics back
        l_BLOCKS[i] = li_new
        m_BLOCKS[i] = mi_new

O = torch.cat(O_BLOCKS, dim=2)
l = torch.cat(l_BLOCKS, dim=2)
m = torch.cat(m_BLOCKS, dim=2)
print(O.shape, time.time() - start_time1)
print(O)

print("----------------FlashAttentionV2------------------------")
O2 = torch.zeros_like(Q, requires_grad=True)
O2_BLOCKS = list(torch.split(O2, Q_BLOCK_SIZE, dim=2))

start_time2 = time.time()
for i in range(Tr):                   # V2: outer loop over Q blocks
    Qi = Q_BLOCKS[i]
    Oi = O2_BLOCKS[i]
    li = torch.zeros((*Q.shape[:-2], Q_BLOCK_SIZE, 1))
    mi = torch.ones((*Q.shape[:-2], Q_BLOCK_SIZE, 1)) * NEG_INF
    for j in range(Tc):               # inner loop over K/V blocks
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]
        S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)
        mi_new = torch.maximum(torch.max(S_ij, dim=-1, keepdim=True)[0], mi)
        P_ij = torch.exp(S_ij - mi_new)
        li = torch.exp(mi - mi_new) * li + torch.sum(P_ij, dim=-1, keepdim=True) + EPSILON
        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
        Oi = torch.exp(mi - mi_new) * Oi + P_ij_Vj   # unnormalized running output
        mi = mi_new
    O2_BLOCKS[i] = Oi / li            # normalize once at the end of the row block

O2 = torch.cat(O2_BLOCKS, dim=2)
print(O2.shape, time.time() - start_time2)
print(O2)

print("----------------Standard Self-Attention------------------------")
start_time3 = time.time()
scores = torch.matmul(Q, K.transpose(-2, -1))
attention_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
print(output.shape, time.time() - start_time3)
print(output)
Program output:
----------------FlashAttentionV1------------------------
torch.Size([1, 1, 6, 4]) 0.0015511512756347656
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)
----------------FlashAttentionV2------------------------
torch.Size([1, 1, 6, 4]) 0.0009410381317138672
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)
----------------Standard Self-Attention------------------------
torch.Size([1, 1, 6, 4]) 0.00012636184692382812
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<UnsafeViewBackward0>)
FlashAttention V1, FlashAttention V2, and standard self-attention produce identical results, and in this run FlashAttention V2 is faster than V1. (At this toy size on CPU the standard implementation is fastest of all; FlashAttention's advantage shows up on GPUs with long sequences, where HBM traffic dominates.)

