训练时用 int8 权重和激活能省 50% 显存、提速 1.8×。但训练是数值敏感的直接把权重强行转 int8 梯度会崩。CANN 的 QAT-W8A8 方案在 forward 时用 int8 计算backward 时用 fp16 伪量化——梯度还是 fp16 的精度。原理伪量化Fake QuantizationForward: fp16 → quantize to int8 → compute → dequantize to fp16 Backward: fp16 gradient不量化梯度Forward 模拟量化误差让模型适应量化。Backward 用 fp16 梯度保证收敛性。实现QAT W8A8importtorchfromtorch_npu.contribimportQATW8A8 modelAutoModelForCausalLM.from_pretrained(meta-llama/Llama-2-7b-hf,torch_dtypetorch.bfloat16,device_mapnpu:0,)# 包装成 QAT W8A8 模型qat_modelQATW8A8(model,weight_bits8,activation_bits8,calib_dataloadercalib_dataloader,# 校准数据集统计激活分布)# 正常训练optimizertorch.optim.AdamW(qat_model.parameters(),lr1e-5)fordataindataloader:lossqat_model(data)loss.backward()optimizer.step()# 训练完成后转成真正量化模型quant_modeltorch.ao.quantization.convert(qat_model)torch.save(quant_model.state_dict(),model_w8a8_quant.pt)校准数据集W8A8 需要校准数据集来统计激活的分布min/max 或 percentile。校准集要跟训练/推理数据同分布。# 用训练集的前 500 条做校准calib_dataloaderDataLoader(train_dataset.select(range(500)),batch_size4,shuffleFalse,)校准集太小100 条→ 激活分布统计不准量化误差大。校准集太大2000 条→ 校准时间长30-60 分钟。精度损失Llama2-7BCANN 8.5Atlas 800I A2量化方案WNLI (准确率)GSM8K (准确率)训练速度fp16 (基准)78.5%56.2%1.0×QAT W8A877.8% (-0.7%)55.1% (-1.1%)1.7×PTQ W8A876.1% (-2.4%)53.8% (-2.4%)- (推理 1.8×)QAT W8A8 的精度损失只有 PTQ W8A8 的 1/3。训练速度提升 70%显存省了batch 可以开更大。显存节省Llama2-7B 训练显存配置权重 (GB)梯度 (GB)优化器状态 (GB)激活 (GB)总计 (GB)fp161414282076W8A8 QAT77142048显存从 76GB 降到 48GB。单卡 64GB 能跑不需要 8 卡 TP。跟 LoRA 的配合QAT W8A8 和 LoRA 可以一起用frompeftimportLoraConfig,get_peft_modelfromtorch_npu.contribimportQATW8A8# 先加 LoRAmodelget_peft_model(model,lora_config)# 再包装 QAT W8A8qat_modelQATW8A8(model,weight_bits8,activation_bits8)LoRA 参数用 fp16 训练参数量小量化收益低基座参数用 int8 训练参数量大量化收益高。推理部署训练好的 QAT W8A8 模型推理时直接用 int8 GEMMfromatbimportLLM modelLLM(model_w8a8_quant.pt,devicenpu:0,quantizew8a8_qat,)ATB 内部调用 int8 GEMM kernel吞吐是 fp16 的 1.8×。跟 AOE 的配合QAT W8A8 的 int8 GEMM Tiling 参数也可以用 AOE 调优aoe--job_type2\--model_pathmodel_w8a8.onnx\--configaoe_config_w8a8.jsonint8 GEMM 的 Tiling 搜索空间比 fp16 小因为 int8 的 Cube 分块大小固定调优时间约 30 分钟。QAT W8A8 是训练时量化的最佳实践——forward 用 int8 提速backward 用 fp16 保精度。显存省 37%训练速度提 70%精度损失 1%。仓库在这里https://atomgit.com/cann/torch_npuhttps://atomgit.com/cann/AMCT