01
From Online Softmax to FlashAttention
import numpy as np

inputs = np.array([0, 7, 6, 12, 10], dtype=np.float16)
eInputs = np.exp(inputs)
result = eInputs / np.sum(eInputs)
print(eInputs)
print(np.sum(eInputs))
print(result)
[1.000e+00 1.097e+03 4.035e+02       inf 2.203e+04]
inf
[ 0.  0.  0. nan  0.]
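The inf and nan come from float16's limited range: the largest finite float16 value is 65504, while exp(12) ≈ 162755, so that entry overflows to inf and the subsequent inf/inf division yields nan. A quick check (not part of the original post) makes this visible:

import numpy as np

print(np.finfo(np.float16).max)   # 65504.0, the largest finite float16 value
print(np.exp(np.float16(12)))     # inf: exp(12) ≈ 162755 exceeds that range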
import numpy as np

inputs = np.array([0, 7, 6, 12, 10], dtype=np.float16)
max_val = max(inputs)
emInputs = np.exp(inputs - max_val)
result1 = emInputs / np.sum(emInputs)
print(emInputs)
print(np.sum(emInputs))
print(result1)
print(sum(result1))
[6.139e-06 6.737e-03 2.480e-03 1.000e+00 1.354e-01]
1.145
[5.364e-06 5.886e-03 2.167e-03 8.735e-01 1.183e-01]
0.9998794794082642
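This version computes the usual safe softmax: subtracting the row maximum before exponentiating keeps every exponent at or below zero, so no term can overflow; the remaining deviation of the sum from 1 (0.99988 above) is ordinary float16 rounding. In formula form:

m = \max_j x_j, \qquad \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}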
import numpy as np

inputs = np.array([1, 1, 3, 3, 3], dtype=np.float16)

# softmax
eInputs = np.exp(inputs)
result = eInputs / np.sum(eInputs)
print(result)

# safe-softmax
max_val = max(inputs)
emInputs = np.exp(inputs - max_val)
result1 = emInputs / np.sum(emInputs)
print(result1)
[0.04138 0.04138 0.3057  0.3057  0.3057 ]
[0.04138 0.04138 0.3057  0.3057  0.3057 ]
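The two printouts match because softmax is shift-invariant: subtracting any constant c (here the maximum) from every input cancels out of the ratio:

\frac{e^{x_i - c}}{\sum_j e^{x_j - c}} = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}} = \frac{e^{x_i}}{\sum_j e^{x_j}}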
# online softmax: 2 passes (fused max and denominator, then normalize)
import torch

L = 8
inputs = torch.randn(L)
result = torch.zeros(L)

m = torch.tensor(float("-inf"))
d = 0
for i in range(L):
    m_new = torch.max(m, inputs[i])
    d = d * (m - m_new).exp() + (inputs[i] - m_new).exp()
    m = m_new
for i in range(L):
    result[i] = (inputs[i] - m).exp() / d
print('online softmax result:', result)
print(torch.sum(result))

# safe softmax: 3 passes (max, exp-sum, normalize)
max_value = torch.max(inputs)
eX = torch.exp(inputs - max_value)
result1 = eX / torch.sum(eX)
print('safe softmax result:', result1)
print(torch.sum(result1))
online softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])
tensor(1.)
safe softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])
tensor(1.0000)
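The same recurrence can be carried one step further: instead of only tracking the running max m and denominator d, the weighted sum of values can be accumulated in the same pass and normalized once at the end, removing the second loop entirely. This is the step that leads from online softmax to FlashAttention. Below is a minimal single-query sketch (the function name one_pass_attention and the test shapes are my own additions, not from the original post):

import torch

def one_pass_attention(q, K, V):
    # Single-query attention with a one-pass online softmax:
    # keep a running max m, denominator d, and unnormalized output o,
    # and rescale the old state whenever the max changes.
    m = torch.tensor(float("-inf"))
    d = torch.tensor(0.0)
    o = torch.zeros(V.shape[-1])
    for k, v in zip(K, V):
        x = torch.dot(q, k)              # one attention score
        m_new = torch.maximum(m, x)
        alpha = (m - m_new).exp()        # rescaling factor for the old state
        p = (x - m_new).exp()
        d = d * alpha + p
        o = o * alpha + p * v
        m = m_new
    return o / d                         # normalize once at the end

q = torch.randn(4)
K, V = torch.randn(8, 4), torch.randn(8, 4)
print(torch.allclose(one_pass_attention(q, K, V),
                     torch.softmax(K @ q, dim=0) @ V, atol=1e-6))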
02
Flash Attention
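FlashAttention applies the online-softmax idea block by block: Q, K and V are split into tiles, and for each query tile a running row max m_i, running denominator l_i and running output O_i are kept and rescaled as new key/value tiles arrive. For a score tile S_ij = Q_i K_j^T, the update implemented in steps 12-15 of the code in the next section can be written as:

\tilde m_{ij} = \mathrm{rowmax}(S_{ij}), \quad
\tilde P_{ij} = e^{S_{ij} - \tilde m_{ij}}, \quad
\tilde l_{ij} = \mathrm{rowsum}(\tilde P_{ij})

m_i^{\mathrm{new}} = \max(m_i, \tilde m_{ij}), \quad
l_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}} l_i + e^{\tilde m_{ij} - m_i^{\mathrm{new}}} \tilde l_{ij}

O_i \leftarrow \frac{l_i}{l_i^{\mathrm{new}}} e^{m_i - m_i^{\mathrm{new}}} O_i
  + \frac{e^{\tilde m_{ij} - m_i^{\mathrm{new}}}}{l_i^{\mathrm{new}}} \tilde P_{ij} V_j

FlashAttention-2 defers the division by l_i until after the inner loop, normalizing each output block only once, which is exactly the difference between the two implementations below.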
03
Code Analysis
import time
import torch

torch.manual_seed(0)

NEG_INF = -1e10  # stand-in for -infinity
EPSILON = 1e-10
Q_LEN = 6
K_LEN = 6
Q_BLOCK_SIZE = 3
KV_BLOCK_SIZE = 3
P_DROP = 0.2
Tr = Q_LEN // Q_BLOCK_SIZE    # Tr: number of query blocks
Tc = K_LEN // KV_BLOCK_SIZE   # Tc: number of key/value blocks

Q = torch.randn(1, 1, Q_LEN, 4, requires_grad=True).to(device='cpu')
K = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')
V = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')

# step 4: split Q, K, V into blocks along the sequence dimension
Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)

print("----------------FlashAttentionV1------------------------")
O = torch.zeros_like(Q, requires_grad=True)
l = torch.zeros(Q.shape[:-1])[..., None]
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF
# print(O.shape, l.shape, m.shape)

# step 5: split the output and the running statistics to match the query blocks
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))
# print(O_BLOCKS[0].shape, l_BLOCKS[0].shape, m_BLOCKS[0].shape)

# step 6: outer loop over key/value blocks
start_time1 = time.time()
for j in range(Tc):
    # step 7
    Kj = K_BLOCKS[j]
    Vj = V_BLOCKS[j]
    # step 8: inner loop over query blocks
    for i in range(Tr):
        # step 9
        Qi = Q_BLOCKS[i]
        Oi = O_BLOCKS[i]
        li = l_BLOCKS[i]
        mi = m_BLOCKS[i]

        # step 10: attention scores for this block pair, S_ij = Qi @ Kj.T
        S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)

        # step 11: (optional) masking
        # mask = S_ij.ge(0.5)
        # S_ij = torch.masked_fill(S_ij, mask, value=0)

        # step 12: block-local max, exponentials and row sums
        m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
        P_ij = torch.exp(S_ij - m_block_ij)
        l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)

        # step 13: merge block statistics into the running max and denominator
        mi_new = torch.maximum(m_block_ij, mi)
        li_new = torch.exp(mi - mi_new) * li + \
                 torch.exp(m_block_ij - mi_new) * l_block_ij

        # step 14: (optional) dropout
        # m = torch.nn.Dropout(p=P_DROP)
        # P_ij_Vj = m(P_ij_Vj)

        # step 15: rescale the old output and add this block's contribution
        O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi \
                      + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
        # print(f'-----------Attention : Q{i}xK{j}---------')
        # print(O_BLOCKS[i].shape)
        # print(O_BLOCKS[0])
        # print(O_BLOCKS[1])
        # print('\n')

        # step 16: store the updated running statistics
        l_BLOCKS[i] = li_new
        m_BLOCKS[i] = mi_new

O = torch.cat(O_BLOCKS, dim=2)
l = torch.cat(l_BLOCKS, dim=2)
m = torch.cat(m_BLOCKS, dim=2)
print(O.shape, time.time() - start_time1)
print(O)

print("----------------FlashAttentionV2------------------------")
O2 = torch.zeros_like(Q, requires_grad=True)
O2_BLOCKS = list(torch.split(O2, Q_BLOCK_SIZE, dim=2))

start_time2 = time.time()
for i in range(Tr):                      # outer loop over query blocks
    Qi = Q_BLOCKS[i]
    Oi = O2_BLOCKS[i]
    li = torch.zeros((*Q.shape[:-2], Q_BLOCK_SIZE, 1))
    mi = torch.ones((*Q.shape[:-2], Q_BLOCK_SIZE, 1)) * NEG_INF
    for j in range(Tc):                  # inner loop over key/value blocks
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]
        S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)
        mi_new = torch.maximum(torch.max(S_ij, dim=-1, keepdims=True)[0], mi)
        P_ij = torch.exp(S_ij - mi_new)
        li = torch.exp(mi - mi_new) * li + torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
        Oi = torch.exp(mi - mi_new) * Oi + P_ij_Vj
        mi = mi_new
    O2_BLOCKS[i] = Oi / li               # single normalization after the inner loop

O2 = torch.cat(O2_BLOCKS, dim=2)
print(O2.shape, time.time() - start_time2)
print(O2)

print("----------------Standard Self-Attention------------------------")
start_time3 = time.time()
scores = torch.matmul(Q, K.transpose(-2, -1))
attention_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
print(output.shape, time.time() - start_time3)
print(output)
----------------FlashAttentionV1------------------------
torch.Size([1, 1, 6, 4]) 0.0015511512756347656
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)
----------------FlashAttentionV2------------------------
torch.Size([1, 1, 6, 4]) 0.0009410381317138672
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)
----------------Standard Self-Attention------------------------
torch.Size([1, 1, 6, 4]) 0.00012636184692382812
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<UnsafeViewBackward0>)
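Since all three paths compute the same attention, a quick check (an illustrative addition, run in the same session right after the script above) confirms the outputs agree up to floating-point tolerance:

# run after the script above, in the same session
print(torch.allclose(O, output, atol=1e-6))   # FlashAttention V1 vs. standard attention
print(torch.allclose(O2, output, atol=1e-6))  # FlashAttention V2 vs. standard attention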