DeepSeek V3/V2 Sparse Flash Attention with Quantization【免费下载链接】cann-outreach项目地址: https://gitcode.com/cann/cann-outreach基于 PyPTO 实现的 DeepSeek V3/V2 MLAMulti-head Latent Attention稀疏注意力算子支持 KV Cache INT8 量化和 PagedAttention运行于华为昇腾 NPU。算法概述本算子实现了 DeepSeek V3/V2 架构中的稀疏注意力机制核心特征如下MLA 压缩表示Query/Key 被拆分为nope低秩压缩部分kv_lora_rank512和rope旋转位置编码部分qk_rope_dim64两个子向量拼接后参与注意力计算。稀疏 Top-K 选择每个 query token 仅关注从 KV Cache 中选出的topk个 key-value 对默认 topk2048而非全序列。KV Cache INT8 量化Key 的nope部分支持 INT8 逐 block 量化128 元素为一组配合 FP32 scale 反量化后参与计算。PagedAttentionKV Cache 以 blockblock_size128为粒度管理通过block_table映射物理位置支持不连续内存布局。变长序列每个 batch 可有不同的实际序列长度actual_seq。算子规格项目说明算子名称sparse_flash_attention_quant数据类型BF16 (query/key_nope/key_rope/output), INT8 (可选 key_nope), FP32 (scales)精度标准rtol0.005, atol0.0001动态轴query_nope/query_rope/topk_indices/block_table/kv_act_seqs 的首维为动态推理模式Decode (s11/2, 标准 softmax) / Prefill (s1256, Flash online softmax)输入输出参数方向shapedtype说明query_nope输入(BS1N_Q, kv_lora_rank)BF16Query 低秩压缩部分query_rope输入(BS1N_Q, qk_rope_dim)BF16Query 旋转位置编码部分key_nope_2d输入(block_num*block_size, kv_lora_rank)BF16 / INT8Key 低秩压缩部分 (KV Cache)key_rope_2d输入(block_num*block_size, qk_rope_dim)BF16Key 旋转位置编码部分 (KV Cache)k_nope_scales输入(block_num*block_size, 4)FP32Key INT8 反量化 scale (每 128 元素一组)topk_indices输入(BS1, N_KVtopk)INT32每个 query token 的 top-k 索引block_table输入(B, max_blocknum_perbatch)INT32PagedAttention block 映射表kv_act_seqs输入(B,)INT32每个 batch 的实际 KV 序列长度attention_out输出(B, S1, N_Q, kv_lora_rank)BF16注意力计算结果文件结构attention/ ├── deepseekv32_sparse_flash_attention_quant.py # 测试入口与 golden 生成 ├── sparse_flash_attention_quant_impl.py # PyPTO kernel 实现 ├── README.md └── utils/ └── compare.py # 精度对比工具实现版本函数模式算法适用芯片sparse_flash_attention_quant_dDecode标准 softmax910Bsparse_flash_attention_quant_d_950Decode标准 softmax950sparse_flash_attention_quant_pPrefillFlash Attentiononline softmax910BDecode 模式sparse_flash_attention_quant_compute使用标准 softmax 归一化softmax exp(S - max(S)) / sum(exp(S - max(S)))每次 s2 tile 计算后直接得到归一化结果写入输出适用于 s1 较小如 s11 或 s12的 decode 场景Prefill 模式sparse_flash_attention_quant_compute_flash使用 Flash Attention 算法的 online softmax维护oi_update累加输出、li_update累加归一化因子、mi_update累加最大值三个运行状态跨 s2 tile 增量更新mi_new max(mi, tilda_mij)→ 修正历史累加值 → 归一化仅在最后一个 s2 tile 时做最终归一化减少中间精度损失适用于 s1 较大如 s1256的 prefill 场景实现要点1. 计算流水线每个 s2 tile对每个 batch、每个 s1 token、每个 KV head 组: ├─ Sa_V0: Gather — 从 KV Cache 按 topk_indices 搬运 Key/Value 数据 │ ├─ 若 INT8 量化: Gather INT8 kn scales → 反量化 → BF16 │ └─ 若 BF16: 直接 Gather BF16 kn ├─ Sa_C1: S Q × K^T (BF16 → FP32 matmul) ├─ Sa_V1: Softmax(S * scale) │ ├─ Decode: 标准 softmax (exp-max / sum) │ └─ Prefill: Flash online softmax (exp-max, 不除 sum, 累积 oi/li/mi) ├─ Sa_C2: O Softmax × V (BF16 matmul) └─ Sa_V2: Flash 归一化更新 (仅 Prefill 模式, 最后一个 tile 时 O oi / li)2. 涉及的 PyPTO API流程控制API用途pypto.frontend.jitKernel JIT 编译装饰器配置 pass_options / runtime_optionspypto.loop生成硬件级循环batch / s1 / n_kv / group / s2pypto.loop_unroll循环展开s2 tile 维度pypto.cond条件分支首个/末个 tile 判断张量构造与视图API用途pypto.view创建张量视图切片 topk_indices / block_table / query / keypypto.reshape张量形状变换INT8 反量化 reshape 对齐pypto.concat张量拼接扩展 INT8 列宽用于 reshapepypto.assemble将子张量写入目标偏移位置拼接 Key/Query, 写回输出计算算子API用途pypto.matmul矩阵乘法C1(Q×K^T) 和 C2(Softmax×V)pypto.amax沿指定维度求最大值Softmax 数值稳定减最大值防溢出pypto.exp逐元素指数运算Softmax 核心pypto.sum沿指定维度求和Softmax 归一化因子pypto.maximum逐元素取最大值Flash Attention: mi_new max(mi, tilda_mij)pypto.mul逐元素乘法scale × S, INT8 反量化, Flash 增量修正pypto.sub逐元素减法S - max, Flash 增量修正因子pypto.add逐元素加法Flash 增量更新 li_new, oi_newpypto.div逐元素除法Softmax 归一化 / Flash 最终归一化 O oi / lipypto.cast数据类型转换INT8→FP16→FP32→BF16 反量化链路pypto.gather按 topk_indices 从 KV Cache 搬运数据gather_in_ub / gather_in_l1 的底层依赖Tiling 与编译配置API用途pypto.set_vec_tile_shapes设置向量算子 tile 尺寸Gather / Softmax / Flash 更新pypto.set_cube_tile_shapes设置 Cube matmul tile 尺寸C1: Q×K^T, C2: Softmax×Vpypto.set_matrix_size设置 matmul 矩阵尺寸 [M, K, N]pypto.set_semantic_label设置语义标签Sa_V0 / Sa_C1 / Sa_V1 / Sa_C2 / Sa_V2pypto.set_pass_options编译期 Pass 选项BF16 路径 scope 隔离3. 关键设计决策5 层嵌套循环结构Decode 与 Prefill 共享同一循环骨架L0 batch→pypto.loop, Decode 可并行 (parallelTrue), Prefill 串行L1 s1→pypto.loop, query 序列维度L2 n_kv→pypto.loop, KV head 维度 (GQA)L3 group→pypto.loop, GQA group 维度 (N_Q / N_KV 128)L4 s2→pypto.loop_unroll, KV 序列 tile 维度 (unroll_list{1})INT8 量化路径Keynope部分按每 128 元素分组量化512 / 128 4 组gather_in_ub分别搬运 INT8 kn 和 FP32 scalesINT8 → FP16 → FP32 类型提升链逐组乘 scale 完成反量化转回 BF16 参与后续计算Flash Online Softmax 增量更新Prefill 模式跨 s2 tile 维护三个运行状态oi_update: 累积注意力输出未归一化li_update: 累积 exp sum, shape(1, group_tile)mi_update: 累积 max, shape(1, group_tile)每个后续 tile 的修正公式mi_new max(mi, tilda_mij) li_new exp(mi - mi_new) * li exp(tilda_mij - mi_new) * tilda_lij oi_new exp(mi - mi_new) * oi exp(tilda_mij - mi_new) * q1仅在最后一个 tile 时做最终归一化O oi / li。Tiling 配置通过SaTileShapeConfig控制各级计算的 tile 参数dataclass class SaTileShapeConfig: g_tile: int # GQA group tile 大小 s_kv_tile: int # KV 序列维度 tile 大小 gather_vec_tile_shape: list # Gather 向量算子 tile c1_tile_shape: list # C1Q×K^TCube 算子 tile [M0,M1, K0,K1, N0,N1] v1_tile_shape: list # V1Softmax向量算子 tile c2_tile_shape: list # C2Attn×VCube 算子 tile [M0,M1, K0,K1, N0,N1] v2_tile_shape: list # V2Flash 归一化更新向量算子 tile仅 prefill 使用不同芯片/模式的默认配置参数910B Decode910B Prefill950 Decodeg_tile128128128s_kv_tile204820482048gather_vec_tile_shape[32, 512][32, 512][64, 512]c1_tile_shape[128,128,128,128,128,128][128,128,128,128,128,128][128,128,128,128,64,64]v1_tile_shape[8, 2048][8, 2048][4, 2048]c2_tile_shape[128,128,128,128,128,128][128,128,128,128,128,128][128,128,128,128,128,128]v2_tile_shape[64, 256][64, 128][64, 256]测试用例用例名BS1N_QN_KV序列长度Key 量化模式芯片sfa_bf16_b4_s2_seq64K_total_int8_d421281[65536, 16381, 666, 15]INT8Decode910B/950sfa_bf16_b4_s2_seq64K_per_int8_d421281[65536]×4INT8Decode910B/950sfa_bf16_b4_s2_seq64K_per_bf16_d421281[65536]×4BF16Decode910B/950sfa_bf16_b1_s256_seq64K_int8_p12561281[65536]INT8Prefill910Bsfa_bf16_b4_s2_seq64K_per_int8_d_950421281[65536]×4BF16Decode950精度容差atol0.0001, rtol0.005。运行方式环境准备# 配置 CANN 环境变量 source /usr/local/Ascend/ascend-toolkit/set_env.sh # 设置设备 ID export TILE_FWK_DEVICE_ID0通过 pytest 运行# 运行默认用例 pytest deepseekv32_sparse_flash_attention_quant.py -v # 运行指定用例 pytest deepseekv32_sparse_flash_attention_quant.py::test_sfa_bf16_b4_s2_seq64k_total_int8_d -v # 指定芯片类型 pytest deepseekv32_sparse_flash_attention_quant.py -v --soc910B直接运行python deepseekv32_sparse_flash_attention_quant.py默认执行test_sfa_bf16_b4_s2_seq64k_per_int8_d()。可编辑__main__部分切换用例。关键依赖pyptoPyPTO 算子开发框架pypto.experimental.gather_in_ub/gather_in_l1PagedAttention block 级稀疏 Gathertorch/torch_npuPyTorch 及昇腾 NPU 后端numpy数据生成辅助pytest测试框架环境要求CANN 工具链Ascend 910B / 950PRPyPTO 框架PyTorch torch_npu环境变量TILE_FWK_DEVICE_ID已设置如export TILE_FWK_DEVICE_ID0注意事项环境要求需要可用 NPU 环境npu-smi info可检测到设备TILE_FWK_DEVICE_ID环境变量可指定设备编号默认为 0。JIT 编译三个 kernel 入口函数通过pypto.frontend.jit装饰器注册首次调用会触发编译。INT8 量化规则Keynope部分按每 128 个元素分组求绝对最大值作为 scale量化到[-128, 127]范围。Gather 算子使用pypto.experimental.gather_in_ub/gather_in_l1实现 PagedAttention 的 block 级稀疏索引。Tiling 调优算子性能高度依赖于set_vec_tile_shapes/set_cube_tile_shapes的设置950 芯片因 UB/L1 容量不同需使用独立的 TileShape 配置。DYNAMIC Loops2 tile 维度使用pypto.loop_unroll(unroll_list{1})首次调用时确定循环次数并编译后续调用可使用更少的迭代次数但不可超出首次编译时的循环次数。debug_optionssparse_flash_attention_quant_d910B Decode开启了runtime_debug_mode和compile_debug_mode正式发布时应移除。【免费下载链接】cann-outreach项目地址: https://gitcode.com/cann/cann-outreach创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考