链载Ai

标题: 一起聊聊Nvidia Hopper新特性之WGMMA [打印本页]

作者: 链载Ai    时间: 4 小时前
标题: 一起聊聊Nvidia Hopper新特性之WGMMA


上次为大家带来了Hopper上的新特性之TMA,这次我们来一起看看Hopper上的新矩阵乘法操作WGMMA。

引子

如果一个 CUDA 教程没有关于通用矩阵乘法(GEMM)的章节,那么就是不完整的。可以说,GEMM 是现代 GPU 上最重要的例程,它在神经网络、大型语言模型和许多图形应用程序中构成了大部分计算。尽管 GEMM 无处不在,但它以难以有效实现而闻名。

这个由三部分组成的教程系列旨在让读者全面了解如何使用 CUTLASS 库在 NVIDIA Hopper GPU 上编写高效的 GEMM 内核。

本系列的三个部分大致遵循通用矩阵乘(GEMM)内核的整个开发过程,但采用“由内向外”的方式。首先,我们有按分块进行的 GEMM 基本操作,它调用张量核心(Tensor Cores)来最终进行计算。其次,我们有从每个线程束协同线程组(CTA)角度看到的 GEMM 内核设计——由序言、主循环和尾声组成——其中主要挑战是避免内存加载成为快速张量核心的瓶颈。最后,我们在最外层网格级别对 CTA 进行调度,此时负载平衡考虑因素成为首要问题。

我们希望在阅读完本系列后,读者将成为 GEMM 算法的专家,并能够利用该算法中的一些优秀理念来设计和实现他们自己工作中的其他内核。

Asynchronous Warpgroup MMA (WGMMA)

Hopper 引入了异步线程束组级矩阵乘法累加运算(WGMMA)。一个线程束组由四个连续的线程束组成,即 128 个连续的线程,其中第一个线程束的线程束编号是 4 的倍数。wgmma.mma_async指令由线程束组中的所有 128 个线程共同执行。此操作通常采用以下形式之一,其中矩阵C用作累加器:

WGMMA 的一个显著要求是操作数B必须始终存储在共享内存(SMEM)中。相比之下,操作数A可以位于共享内存或寄存器内存(RMEM)中,并且累加器C始终保存在 RMEM 中。

这篇博客文章的结构如下。首先,我们讨论在 CUTLASS 中调用wgmma.mma_async指令的要点。这涉及构建相关的TiledMMA,以及创建和划分 SMEM 张量以与 WGMMA 兼容。其次,我们讨论确保 WGMMA 正确性所需的同步机制。最后,我们更详细地讨论 WGMMA 中使用的布局,包括来自 SMEM 的操作数的核心矩阵和矩阵描述符的概念。

在整个过程中,为了简洁起见,我们将wgmma.mma_async缩写为wgmma

CUTLASS kernel 中的WGMMA

在本教程中,我们的主要目标是解释用于调用 Hopper Tensor Cores 进行基于分块的 GEMM 的wgmma原语,以及如何将其作为cute::gemm调用的一部分进行调用。为了做好准备,考虑一个标准的 GEMM 内核,它接收维度为MxNxK的输入矩阵A和B,并计算C = A*B。为了并行化计算,内核固定静态分块大小bM、bN和bK,并启动一个由⌈M/bM⌉x⌈N/bN⌉多个线程块(CTAs),每个 CTA 计算输出矩阵的一个bMxbN瓦片rC。这将在被写回到全局C矩阵之前保存在 CTA 的本地内存(RMEM)中。

根据CTA,我们就有了内核的主循环。通过多次迭代,我们循环内部维度,并将A和B的bMxbK和bNxbK块依次从全局加载到共享内存中,作为sA和sB;请注意,在CUTLASS中,我们将sB的形状固定为数学上的转置。(事实上,反映了常见的做法,我们将A和B的块加载到循环ME M缓冲区中,其中的级数由编译时整数给出,例如2或3。然后sA和sB的形状元组的最后一种模式由该阶段计数给出。)cute::gemm调用然后计算sA和sB的(分阶段切片)的乘积,并将值连续累加到rC中。主循环完成后,最后将rC写入全局内存。

现在,我们希望解释以下cute::gemm调用及其参数。

template <class TiledMMA, ... >
__global__ device_gemm(TiledMMA tiled_mma, ...) {
// PROLOGUE
// ...
// Define A/B partitioning and C accumulators
ThrMMA thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)

// Allocate accumulators and clear them
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
clear(tCrC);

// Allocate"fragments"
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)

// PIPELINED MAIN LOOP
while(k_tile_count > -K_PIPE_MAX) {
// ...
// MMAs to cover 1 K_TILE
cute::warpgroup_arrive();
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);
cute::warpgroup_commit_batch();
// Waitforall MMAsina K_TILE to complete
cute::warpgroup_wait<0>();
// ...
}

// EPILOGUE
// ...
}

在 CUTLASS 的 MMA(矩阵乘法累加)范式中,“MMA范式”里的cute::gemm方法旨在通过统一的接口展示特定架构的 MMA 指令。(实际上,如果你查看SM80 教程的 GEMM 内核,你会看到那里的cute::gemm调用在语法上与上述相同。)然而,cute::gemm调用所涉及的参数定义包含许多 WGMMA 特定的方面:

最后,当然,在cute::gemm调用周围有线程组同步原语。我们将依次解释所有这些概念。

WGMMA中的TiledMMA对象

以下内容中,假设数据类型为 FP16,且A和B是MN,所以在 BLAS 表示法中,我们正在计算一个NT gemm。我们使用cute::make_tiled_mma方法在主机上构造TiledMMA对象,如下所示:

TiledMMA tiled_mma = cute::make_tiled_mma(
SM90_64x64x16_F16F16F16_SS<GMMA::Major::MN,GMMA::Major::MN>{});

虽然cute::make_tiled_mma也有一些可选参数,但让我们专注于当前的这个参数——矩阵乘法累加原子(MMA Atom)。这是一个结构体,它封装了一个底层的 PTX 调用,在这种情况下是:

wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16

CUTLASS 符号系统使得人们可以立即读出包装后的 PTX 指令与 MMA 原子之间的关系。首先,SM90 是 Hopper 架构的另一个名称。然后,SM90 MMA 原子被标记为SM90_MxNxK_XYZ_SS或SM90_MxNxK_XYZ_RS,其中有两个模板参数可以是GMMA::Major::MNGMMA::Major::K。它们的含义如下:

这就是你需要了解的 MMA Atom 的语法!现在,我们已经强调过 WGMMA 是一个全线程组指令。在代码中,你可以使用其大小来检索参与由 TiledMMA 对象定义的 MMA 操作的线程数量。例如,以下主机代码。

dim3 dimBlock(cute::size(tiled_mma));

规定内核中的每个 CTA 以 1 个包含 128 个线程的线程束组启动。假设我们想要2 个线程束组来执行 WGMMA,由不同的线程束组独立计算输出块的一半(并且每个线程束组发出各自的wgmma指令)。为此,我们可以将一个非平凡的布局(AtomLayoutMNK)作为第二个参数传递给make_tiled_mma方法。例如,以下代码。

TiledMMA tiled_mma = make_tiled_mma(
SM90_64x64x16_F16F16F16_SS{},
Layout<Shape<_2,_1,_1>>{});

定义了一个 WGMMA 操作,其中 warp 组 1 和 2 分别计算输出瓦片的上半部分和下半部分,沿M模式划分(现在假设bM是 128 的倍数)。此外,size(tiled_mma)将等于 256。

一般来说,make_tiled_mma的两个可选布局参数——AtomLayoutMNK和PermutationMNK——对于任何 MMA 原子都同样适用。

共享内存的布局约束了WGMMA

接下来,我们解释在给定 MMA 原子选择的情况下,共享内存中操作数矩阵的瓦片大小和布局的约束。首先,对于任何 MMA 指令,MMA 原子的MxNxK需要能够整除操作数和累加器Tile的大小。在我们的例子中,这意味着bM应该是 64 的倍数,bN是 64 的倍数,bK是 16 的倍数。

其次,WGMMA 对sA和sB的共享内存布局(包括形状和跨度)施加了一个额外的约束,并且这个约束会随着所选的交错模式而变化。特别是,(分阶段切片的)sA的布局通常不是简单的(bM,bK)1,bM)或(bM,bK)bK,1),sB也是如此。

为了深入理解这些要求,我们需要“核心矩阵”的概念,我们将在下面介绍。然而,实际上,我们总是可以使用 CUTLASS 提供的某些预定义布局原子,然后使用cute::tile_to_shape方法构建保证与

auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 64>{};
auto bP = Int< 3>{}; // Pipeline

auto sA = cute::tile_to_shape(
GMMA:ayout_MN_SW128_Atom<T>{},
cute::make_shape(bM, bK, bP)
);
auto sB = cute::tile_to_shape(
GMMA:ayout_MN_SW128_Atom<T>{},
cute::make_shape(bN, bK, bP)
);

在这里,MN表示布局原子适用于MN主操作数,而SW128是 128 字节的交错模式。输出sA或sB会显示。

Sw&lt;3,4,3> o smem_ptr[16b](unset) o ((_64,_2),(_8,_8),_3)(_1,_512),(_64,_1024),_8192)

这个布局是从哪里来的?cute::tile_to_shape采用一个布局(同名的tile)并复制它以平铺在更大的形状上(类似于numpy.tile)。抛开swizzle函数Sw<3,4,3>,我们知道布局原子由(64,8)1,64)给出,并以列主要方式平铺在形状(128, 64, 3)上,因此对于MxK形状,512的较小外步幅位于M模式,而1024的较大外步幅位于K模式。(8192的最大步幅在于阶段计数P模式,这是有道理的,因为sA或sB的不同阶段切片不应该在内存中混合。)

请注意,64乘以sizeof(half_t)等于128字节,这是swizzle模式的名称。这是设计:由于核心矩阵的工作方式,我们总是在连续方向上安排布局原子的长度以等于swizzle字节数-对于无swizzle,可以是16,或者32、64或128之一。

相对的,如果我们考虑:

auto sA = cute::tile_to_shape(
GMMA:ayout_K_SW128_Atom<T>{},
cute::make_shape(bM,bK,bP)
);
auto sB = cute::tile_to_shape(
GMMA:ayout_K_SW128_Atom<T>{},
cute::make_shape(bN,bK,bP)
);

打印sA会得到我们预期的结果。

Sw&lt;3,4,3> o smem_ptr[16b](unset) o (_128,_64,_3)_64,_1,_8192)

由于我们改为在(8,64)64,1)上平铺(8,64)64,1)。(请注意,布局((_8,_16),(_64,_1),_3)(_64,_512),(_1,_0),_8192)合并为(_128,_64,_3)_64,_1,_8192))。

一般来说,我们可以在8种布局原子的可能性中进行选择,它们对应于MN或K为主以及四种混洗模式之一:

GMMA:ayout_MN_INTER_Atom<T>
GMMA:ayout_MN_SW32_Atom<T>
GMMA:ayout_MN_SW64_Atom<T>
GMMA:ayout_MN_SW128_Atom<T>

GMMA:ayout_K_INTER_Atom<T>
GMMA:ayout_K_SW32_Atom<T>
GMMA::Layout_K_SW64_Atom<T>
GMMA::Layout_K_SW128_Atom<T>

然后,必须将这些布局原子传入tile_to_shape,其中sA和sB的共享内存(SMEM)形状由make_shape(bM,bK,bP)make_shape(bN,bK,bP)给出,形状的模式按照该顺序给出,使得布局原子的分块大小能够整除较大的 SMEM 形状的分块大小。这最终是由混洗模式的选择对 SMEM 形状造成的约束,并且与由矩阵乘法累加(MMA)原子形状施加的另一个约束分开。

WGMMA 片段和描述符

我们创建了TiledMMA对象,并在主机上相应地准备了共享内存(SMEM)布局。现在,在设备上,我们可以使用TiledMMA对象tiled_mma来构建适当的分区张量,以便传递到cute::gemm调用中。首先,我们通过在tiled_mma上调用带有线程索引的get_thread_slice方法来创建一个名为thr_mma的ThrMMA对象。在我们的例子中,线程索引从0到127。

接着,参考上面的内核代码片段,打印张量tCsA和tCsB对于任何线程索引,显示如下:

tCsA: Sw&lt;3,4,3>_smem_ptr[16b](0x7f8800000400) o
((_64,(_8,_2)),_2,_4,_3)(_1,(_64,_1024)),_512,_2048,_8192)
tCsB: Sw&lt;3,4,3>_smem_ptr[16b](0x7f880000c400) o
((_64,(_8,_2)),_2,_4,_3):((_1,(_64,_1024)),_512,_2048,_8192)

根据注释,tCsA的形状应被视为(MMA,MMA_M,MMA_K,PIPE):

tCrA: GMMA:escriptorIterator o (_1,_2,_4,_3):(_0,_64,_256,_1024)
tCrB: GMMA:escriptorIterator o (_1,_2,_4,_3):(_0,_64,_256,_1024)

在内部,CUTLASS 构造一个“矩阵描述符”,这是一个保存在寄存器中的 64 位值,以一种适合wgmma指令使用的方式描述共享内存(SMEM)。对于程序员来说,最重要的是要记住,共享内存的值不会被复制到寄存器内存(RMEM)中;相反,访问 tCrA 和 tCrB 的值实际上是访问这些 64 位描述符。此外,这些张量作为“迭代器”意味着在任何时候,对于给定的wgmma指令,只有一个 64 位描述符保存在寄存器中(例如,与全部 24 个不同)。

与操作数相比,累加器张量以更标准的方式定义。打印线程 0 的tCgC和tCrC显示:

tCgC: gmem_ptr[16b](0x7f877a780000) o ((_2,_2,_8),_2,_2):((512,_8,4096),_64,32768)
tCrC: ptr[16b](0x7feee1fffbe0) o ((_2,_2,_8),_2,_2):((_1,_2,_4),_32,_64)

tCgC是输出 GMEM 张量的一部分,我们希望在尾声中将累加器的值复制到该部分,而tCrC是为了在主循环中计算这些值时保存这些值而创建的基于寄存器的张量。这些张量的(MMA,MMA_M,MMA_N)形状可以如下解释:在 MMA 原子的MxN=64x64输出块中,128 个线程中的每个线程都持有32=2*2*8个值,并且MMA_M=MMA_N=2与tCsA和tCsB相同。

每个线程以一种需要将 32 分解为(2,2,8)形状的方式持有原子的 32 个值,以便能够为tCgC的布局定义相应的步长。具体的分区模式可以从取自 PTX 文档的这张图片中读出:

这说明了重复的 Z 模式,其中一个线程的 32 个值被保存。例如,线程 0 保存着(0,0)、(0,1)、(8,0)、(8,1)处的值,并每向右 8 列重复一次。

Gemm call

让我们回到上面内核代码片段的第 25 行:

// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);

cute::gemm方法的各种重载首先用于循环遍历外部模式MMA_M/N和MMA_K。一旦选择了这些坐标,我们就使用矩阵乘法累加器原子瓦片形状进行计算。换句话说,我们首先将其简化为针对cute::gemm的调度形状(V)x(V)=>(V)的重载。

然后,代码调用矩阵乘法累加器原子的fma操作(确切地说,在矩阵乘法累加器解包(mma_unpack))。这里包含了一些PTX汇编代码:

CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t& d00, uint32_t& d01, uint32_t& d02, uint32_t& d03,
uint32_t& d04, uint32_t& d05, uint32_t& d06, uint32_t& d07,
uint32_t& d08, uint32_t& d09, uint32_t& d10, uint32_t& d11,
uint32_t& d12, uint32_t& d13, uint32_t& d14, uint32_t& d15,
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15},"
" %16,"
" %17,"
" p, %19, %20, %21, %22;\n"
"}\n"
:"+r"(d00),"+r"(d01),"+r"(d02),"+r"(d03),
"+r"(d04),"+r"(d05),"+r"(d06),"+r"(d07),
"+r"(d08),"+r"(d09),"+r"(d10),"+r"(d11),
"+r"(d12),"+r"(d13),"+r"(d14),"+r"(d15)
:"l"(desc_a),
"l"(desc_b),
"r"(int32_t(scale_D)),
"n"(int32_t(scaleA)),
"n"(int32_t(scaleB)),
"n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
#else
CUTE_INVALID_CONTROL_PATH(
"Attempting to use SM90_64x64x16_F16F16F16_SS "
"without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
}

这种语法对应的 PTX 文档在此处。与上述对张量 tCrA、tCrB 和 tCrC 的描述一致,请注意,对于操作数我们有 uint64 类型的变量 desc_a 和 desc_b,同时对于累加器有 16 个 uint32 类型的变量。scale_D 的值为 0 或 1,它控制累加器是否进行零初始化。

此外,变量 scaleA、scaleB、tnspA 和 tnspB 是在 fma 方法外部通过模板参数在编译时确定的。scaleA 和 scaleB 的值为 1 或 -1,用于对操作数取负;而 tnspA 和 tnspB 表示是否对操作数进行转置,当值为 0 时对应GMMA::Major::K,值为 1 时对应GMMA::Major::MN

WGMMA的同步

接下来还需解释围绕 cute::gemm 调用的同步原语:

cute::warpgroup_arrive();
cute::gemm(tiled_mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC);
cute::warpgroup_commit_batch();
cute::warpgroup_wait<0>();

为什么这些额外的指令完全有必要呢?这一切都与 wgmma 作为一条异步指令的特性有关。在霍普(Hopper)架构的背景下,“异步” 意味着 wgmma 可以与其他操作并发运行,因此对于有依赖关系的步骤而言,就需要一种同步机制。这种机制在 PTX 内存一致性模型中有详细阐述。代码中如果同步不当,可能会导致以下情况:(a)出现难以察觉的竞态条件,进而引发棘手的错误;(b)编译器会将 wgmma 指令按顺序执行,这可能会导致性能大幅下降;或者(c)出现未定义行为。

cute 方法封装了以下 PTX 指令:

(注意,我们一直用 wgmma 作为 wgmma.mma_async 的简写,但仅在本小节我们会明确区分二者。)让我们把这些指令的用法与从 PTX 文档中逐字引用的以下基于 WGMMA 的通用矩阵乘法(GEMM)描述联系起来:

我们按顺序解释这些要点。首先,wgmma.fence指令可确保wgmma.mma_async仅在对某些寄存器内存(RMEM)地址的所有先前访问完成后,才访问这些地址。如果没有wgmma.fence,其行为是未定义的。该规则的一个例外是,霍普(Hopper)架构允许同时执行多条wgmma.mma_async指令。只要这些wgmma.mma_async指令的累加器形状相同,它们就可以共享同一个累加器张量,即写入相同的寄存器内存地址。在这种情况下,就不需要同步(fence)操作。例如,在cute::gemm调用中对MMA_K进行循环时,我们无需插入wgmma.fence

与张量内存访问(TMA)操作一样,wgmma.mma_async是在异步代理中执行的。因此,如果在通用代理中执行的操作会影响wgmma.mma_async读取的共享内存(SMEM),我们就需要发出fence.proxy.async指令。例如,如果我们通过普通的ld.global/st.shared操作将矩阵 A 和 B 复制到共享内存中,就会出现这种情况。由于我们使用了 TMA 加载,在示例中就不需要fence.proxy.async,实际上,它也未出现在 WGMMA 教程代码或 CUTLASS 霍普架构通用矩阵乘法(GEMM)内核的主循环中。(要验证这一点,请注意fence.proxy.async是由cutlass::arch::fence_view_async_shared()封装的。)

wgmma.commit_group指令会为每个线程束组创建一个新的wgmma组,并将执行线程束组发起但尚未提交到任何wgmma组的所有先前的wgmma.mma_async指令批量处理到这个新的wgmma组中。在我们的示例中,cute::warpgroup_commit_batch()会将MMA_M * MMA_N * MMA_Kwgmma.mma_async指令批量处理到一个wgmma组中。

最后,带有参数 N 的wgmma.wait_group指令会使执行线程等待,直到最近的wgmma组中未完成的数量不超过 N 个,并且执行线程提交的所有先前的wgmma组都已完成。在我们的示例中,我们将 N 设为 0,这样线程束组只需等待整个wgmma组完成,然后再继续执行后续指令。

在线程束组有机会执行独立计算的情况下,参数 N 的灵活性就派上用场了。例如,在 FlashAttention - 3 的设计中采用的 GEMM - softmax 重叠策略就会用到这一点。

WGMMA核心操作

最后这部分将进一步讨论加载到共享内存(SMEM)中的矩阵 A 和矩阵 B 的分块布局要求,假设wgmma的两个操作数均来源于共享内存。为简化讨论,首先假设 A 是按行优先存储,B 是按列优先存储(即两者都是按 K 优先存储)。还要记得,wgmma指令的分块形状 MxNxK 是受限制的,其中 M 为 64,数据类型大小乘以 K 为 32 字节,N 是 8 的倍数,取值范围从 8 到 256。为避免与 A/B 或 sA/sB 混淆,我们将 WGMMA 的原子分块记为 wA 和 wB。

矩阵 wA 和 wB 被划分为许多较小的矩阵,称为核心矩阵。每个核心矩阵都有一个跨步方向和一个连续方向,其在跨步方向上的长度为 8,在连续方向上的长度为 16 字节。矩阵 wA 由 8x2 的核心矩阵组成,矩阵 wB 由 2x(N/8) 的核心矩阵组成。我们通过核心矩阵来展示 wA 和 wB 的分块情况如下(图片取自 PTX 文档):

如上文所述,处于同步流模式(SS 模式)的 wgmma 需要矩阵描述符,即 wA 的描述符(desc-a)和 wB 的描述符(desc-b)作为输入。这种描述符对五个参数进行了编码:

首维字节偏移量(LBO)和跨步步长字节偏移量(SBO)已在上图中标示出来了。

CUTLASS 中的make_gmma_desc方法会根据作为输入提供的共享内存(SMEM)张量的布局来构建描述符(作为GmmaDescriptor的一个实例)。只要输入张量的布局是使用八种规范的通用矩阵乘法(GMMA)布局原子之一以及tile_to_shape来创建的(如之前在 “WGMMA 的共享内存布局约束” 中详细介绍的那样),make_gmma_desc就会准确计算出首维字节偏移量(LBO)和跨步步长字节偏移量(SBO),确定混洗模式,并构建出描述符。例如,GmmaDescriptor描述了在按 K 优先存储的情况下以下可接受的 WGMMA 布局(其中T*sizeof(dtype)=16):

No swizzle    : Swizzle&lt;0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO))
32-byte swizzle : Swizzle&lt;1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T ))
64-byte swizzle : Swizzle&lt;2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T ))
128-byte swizzle : Swizzle&lt;3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T ))

最值得注意的是,对于 64 字节和 128 字节的混洗模式,其步长使得给定的可接受的 WGMMA 布局并非紧凑布局。相反,在 K 方向上会有 2 组或 4 组 WGMMA 原子操作数分块并排堆叠,从而在核心矩阵的 M 模式下产生 4T 和 8T 的步长。换句话说,在混洗时,在内存中会对在 K 模式下逻辑上相邻的 2 个、4 个或 8 个核心矩阵进行交错排列,并且对于 64 字节和 128 字节的混洗模式,这些核心矩阵将属于不同的 WGMMA 原子。

为了内容的完整性,我们也给出在按 MN 优先存储情况下可接受的 WGMMA 布局:

No swizzle    : Swizzle&lt;0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO))
32-byte swizzle : Swizzle&lt;1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO))
64-byte swizzle : Swizzle&lt;2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO))
128-byte swizzle : Swizzle&lt;3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO))

总结

在通用矩阵乘法(GEMM)系列的[第一部分]中,我们探讨了在基于(Hopper)架构的 GEMM 中,将线程束组矩阵乘法与累加(WGMMA)作为基本操作时涉及的核心概念。

WGMMA 需要一个由 128 个线程组成的线程束组来协同执行矩阵乘法,并且只能对矩阵的特定片段进行操作。我们深入探讨了其中涉及的特殊形状和布局,着重介绍了如何使用规范的通用矩阵乘法(GMMA)布局 => 分块转换形状(tile_to_shape)模式来构建确保能被 WGMMA 接受的操作数布局。

为了确保其使用行为明确,WGMMA 还需要特定的同步机制。为此,我们解释了wgmma.fencefence.proxy.asyncwgmma.commit_groupwgmma.wait_groupwgmma.mma_async之间的关联及用途。

最后,我们详细解释了 WGMMA 核心矩阵的内部工作原理,以及 CUTLASS 如何为那些源自共享内存(SMEM)的操作数构建矩阵描述符。

总体而言,这篇博客文章应能让程序员在Hopper架构上编写使用 WGMMA 的 CUTLASS 内核。在[第二部分]中,我们将扩展讨论范围,引入张量内存访问(TMA)技术,以及如何在霍普架构的 GEMM 内核中同时使用 TMA 和 WGMMA,从而实现数据复制和计算的重叠操作。






欢迎光临 链载Ai (https://www.lianzai.com/) Powered by Discuz! X3.5