从零开始写Qwen3目录概述经过前文的提速耗时已经从官方的214%降低到112%本文将从汇编角度猜测一下差距的原因概述使用上一节的输入参数设置为BMBN64和torch相同分析汇编指令torch的指令统计如下triton实现的指令统计如下HMMA 是 Half Matrix Multiply Accumulation的意思这是FlashAttn的核心指令使用张量核进行矩阵乘法加速对比两个统计发现不管是指令数量还是实际执行次数都是一样的差别可能在共享内存加载部分指令执行次数分析单条汇编执行次数常见有这么几种数字2048153602048是无循环的执行次数15360是执行循环的次数2048 2 ⏟ B × 16 ⏟ H × 1024 64 ⏟ Q × 4 ⏟ n w a r p s 2048 \underbrace{2}_{B} \times \underbrace{16}_{H}\times \underbrace{\frac{1024}{64}}_{Q}\times \underbrace{4}_{nwarps}2048B2​​×H16​​×Q641024​​​×nwarps4​​可以计算15360 2048 × 16 − 1 2 153602048\times \frac{16-1}{2}153602048×216−1​所以2048是阶段2和公共部分的执行次数15360是阶段2的执行次数阶段2平均循环了7.5次两个阶段指令数基本一致除了因果遮罩那里阶段1没有所以平均执行次数是8704 87048704张量核张量核是CUDA从Volta开始引入的一个指令专门用于矩阵加速它用一条指令让一个线程束一起完成一个小块矩阵乘法不仅简化了矩阵乘法的编写也加快计算速度减少指令发射耗时。张量核仅支持F16最新架构也支持FP8的不支持F32这可能是FlashAttention不支持F32的一个重要原因从PTX汇编来看张量核的关键指令是ldmatrix.sync.aligned.m8n8.x4.shared.b16{%r11,%r12,%r13,%r14},[%r802048];ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16{%r15,%r16,%r21,%r22},[%r89];mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32{%r7,%r8,%r9,%r10},{%r11,%r12,%r13,%r14},{%r15,%r16},{%r7,%r8,%r9,%r10};对应SASS的汇编就是LDSM.16.MT88.4 LDSM.16.M88.4 HMMA.16816.F32PTX是中间指令SASS是实际汇编PTX的可读性比SASS高多了并且有官方文档但它不是最终结果ncu也看不到torch的ptxldmatrix.sync.aligned.m8n8.x4.shared.b16从指令可以看出这是从共享内存读取数据的同步对齐读取m8,n8意思是一次读取8x8的数据b16表示加载的是16bit的数据.x4表示一次性读取4个寄存器也就是4 × 8 × 8 4\times 8\times 84×8×8个数据也有.x2这种指令这个指令是整个线程束协同完成的而且寄存器是32位一个32位存放两个f16这样一个线程束的一个寄存器就存放8 × 8 8\times 88×8条数据8 × 8 32 × 2 1 \frac{8\times 8}{32\times 2}132×28×8​1顺便一提f32转f16的汇编是F2FP.PACK_AB R114, R114, R113明明是单个值转换却有两个输入这其实就是把两个f16打包到一个f32上节省寄存器数量和指令数量ldmatrix.sync.aligned.m8n8.trans.x4.shared.b16就是转置版的应该是列优先实际上mma计算的时候A B ⊤ AB^\topAB⊤的时候反而不需要转置A B ABAB的时候才需要转置mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32这里m16n8k16名字很明确是这样一个乘法D A B ⊤ C C , D ∈ R 16 × 8 ; A ∈ R 16 × 16 ; B ∈ R 8 × 16 D AB^\topC\quad C,D\in \mathbb{R}^{16\times 8};A\in \mathbb{R}^{16\times 16};B\in \mathbb{R}^{8\times 16}DAB⊤CC,D∈R16×8;A∈R16×16;B∈R8×16它有四个操作数分别对应D,A,B,CACD用4个寄存器B用两个C , D : 16 × 8 32 × 1 4 A : 16 × 16 32 × 4 4 B : 16 × 8 32 × 2 2 C,D: \frac{16\times 8}{32\times 1}4\\ A: \frac{16\times 16}{32 \times 4}4\\ B: \frac{16\times 8}{32 \times 2}2C,D:32×116×8​4A:32×416×16​4B:32×216×8​2张量核指令数量计算n_warps1K16的情况此时没有累加没有分线程束所以读取一次只会用于计算一次B每次只用一半所以是两次得到A加载次数是M/16B加载次数是N/16计算次数是M N 16 × 8 \frac{MN}{16\times 8}16×8MN​n_warps1K16的情况要么固定A要么固定B把另一个并行比如2并行把B并行就是这样A [ B 0 B 1 ] A\left[\begin{matrix} B_0\\B_1\end{matrix}\right]A[B0​B1​​]所以加载次数是M 16 N 16 × n \frac{M}{16}\frac{N}{16\times n}16M​16×nN​计算次数简单除以nM N 16 × 8 × n \frac{MN}{16\times 8 \times n}16×8×nMN​n_warps1K≠16的情况此时必须有累加加载次数没有影响M K 16 × 16 N K 16 × 16 \frac{MK}{16\times 16}\frac{NK}{16\times 16}16×16MK​16×16NK​计算增加M N K 16 × 16 × 8 \frac{MNK}{16\times 16\times 8}16×16×8MNK​n_warps≠1,K≠16的情况此时并行就需要注意累加必须在同一个线程束中所以虽然划分方向多了一个但不能同时划分A和B或者同时划分行列只能还是按照行划分加载次数和K16的情况一致计算次数按照上面计算差异分析torch的指令数量分析64x128和64x128的乘积num_warps4加载次数64 × 128 16 × 16 × 4 64 × 128 16 × 16 8 32 40 \frac{64\times 128}{16\times 16\times 4}\frac{64\times 128}{16\times 16}8324016×16×464×128​16×1664×128​83240计算次数64 × 128 × 64 16 × 8 × 16 × 4 64 \frac{64\times 128\times 64}{16\times 8\times 16 \times 4}6416×8×16×464×128×64​64观察发现torch是对Q并行而简单的triton是对K并行然后计算attn V这个过程attn没有加载直接用寄存器V则用的是 LDSM.xxx.trans 版本加载次数简单除以大小和并行attn V的计算次数和QK^top一致都是64所以torch的FlashAttnV2中有40条LDSM32条LDSM.trans显然是左侧并行和QK^top一样128条HMMA但由于做了2阶段所以全部乘280条LDSM64条LDSM.trans256条HMMAtriton指令数量分析triton实现把attn存到共享内存然后又加载出来加载次数计算就是64 × 64 16 × 16 16 \frac{64\times 64}{16\times 16}1616×1664×64​16这样就在基础的40上又增加16条2倍就是112条然后triton是V并行V的LDSM.trans加载64 × 128 16 × 16 × 4 8 \frac{64\times 128}{16\times 16\times 4}816×16×464×128​82倍就是16条查了一圈triton好像tl.dot要强制加载共享内存不能直接由寄存器计算这里可能有一些代价