# Advanced Attention Mechanisms: Sparse Attention and Linear-Complexity Optimization

## 1. Technical Analysis

### 1.1 The Complexity Problem of Standard Attention

Standard Transformer self-attention costs O(n²d) time, and the attention matrix itself takes O(n²) memory. As the sequence length n grows, the cost rises quadratically: doubling the sequence length roughly quadruples both compute and memory.

Complexity comparison:

- Standard attention: O(n²d)
- Sparse attention: O(nd log n)
- Linear attention: O(nd)

### 1.2 Attention Mechanism Variants

| Type | Complexity | Suitable scenario | Representative model |
| --- | --- | --- | --- |
| Standard attention | O(n²d) | Short sequences | Transformer |
| Sparse attention | O(nd log n) | Medium sequences | Longformer |
| Linear attention | O(nd) | Long sequences | Linformer |
| Sliding window | O(ndk) | Local dependencies | Transformer-XL |

### 1.3 Sparse Attention Patterns

(Figure: schematic of sparse attention patterns.)

- Global attention: selected positions attend to all positions
- Sliding window: each position attends only to positions inside its window
- Banded attention: attention is restricted to a diagonal band
- Axial attention: attention is applied along each axis separately

## 2. Core Implementations

### 2.1 Sparse Attention

The block-local `SparseAttention` below processes queries window by window, and `LongformerAttention` combines a sliding window with a small set of global tokens.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAttention(nn.Module):
    """Block-local attention: each block of queries attends to a local key window."""

    def __init__(self, d_model, num_heads, window_size=512):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.window_size = window_size
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        seq_len = Q.size(1)

        # Project and reshape to (batch, heads, seq_len, d_k)
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        scale = self.d_k ** 0.5

        outputs = []
        # Attend block by block: each query block sees a surrounding key window
        for i in range(0, seq_len, self.window_size):
            end = min(i + self.window_size, seq_len)
            Q_window = Q[:, :, i:end, :]
            k_start = max(0, i - self.window_size // 2)
            k_end = min(seq_len, i + 3 * self.window_size // 2)
            K_window = K[:, :, k_start:k_end, :]
            V_window = V[:, :, k_start:k_end, :]

            scores = torch.matmul(Q_window, K_window.transpose(-2, -1)) / scale
            attn_weights = F.softmax(scores, dim=-1)
            outputs.append(torch.matmul(attn_weights, V_window))

        output = torch.cat(outputs, dim=2)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)


class LongformerAttention(nn.Module):
    """Longformer-style attention: sliding-window local attention plus a few
    global tokens that attend to the whole sequence."""

    def __init__(self, d_model, num_heads, window_size=512, global_tokens=(0,)):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.window_size = window_size
        self.global_tokens = list(global_tokens)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        seq_len = x.size(1)
        Q = self.W_q(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        scale = self.d_k ** 0.5

        output = torch.zeros_like(Q)
        for i in range(seq_len):
            if i in self.global_tokens:
                # Global token: attend to every position
                scores = torch.matmul(Q[:, :, i:i+1, :], K.transpose(-2, -1)) / scale
                attn_weights = F.softmax(scores, dim=-1)
                output[:, :, i:i+1, :] = torch.matmul(attn_weights, V)
            else:
                # Local token: attend to its window plus all global tokens
                start = max(0, i - self.window_size // 2)
                end = min(seq_len, i + self.window_size // 2 + 1)
                scores = torch.matmul(Q[:, :, i:i+1, :],
                                      K[:, :, start:end, :].transpose(-2, -1)) / scale
                v_window = V[:, :, start:end, :]
                for g in self.global_tokens:
                    if start <= g < end:
                        continue  # global token already inside the local window
                    # Append the out-of-window global token's score and value;
                    # softmax is order-invariant as long as scores and values stay aligned
                    g_scores = torch.matmul(Q[:, :, i:i+1, :],
                                            K[:, :, g:g+1, :].transpose(-2, -1)) / scale
                    scores = torch.cat([scores, g_scores], dim=-1)
                    v_window = torch.cat([v_window, V[:, :, g:g+1, :]], dim=2)
                attn_weights = F.softmax(scores, dim=-1)
                output[:, :, i:i+1, :] = torch.matmul(attn_weights, v_window)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)
```
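Both modules can be exercised with a quick shape check; the sizes below are illustrative assumptions, not benchmarks from this post:

```python
# Hypothetical smoke test for the two modules above; sizes are assumptions.
import torch

d_model, num_heads, seq_len, batch = 64, 4, 1024, 2
x = torch.randn(batch, seq_len, d_model)

sparse = SparseAttention(d_model, num_heads, window_size=128)
longformer = LongformerAttention(d_model, num_heads, window_size=128, global_tokens=(0,))

print(sparse(x, x, x).shape)   # torch.Size([2, 1024, 64])
print(longformer(x).shape)     # torch.Size([2, 1024, 64])
```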
### 2.2 Linear Attention

`LinearAttention` follows the kernel-feature-map formulation (softmax replaced by φ(x) = elu(x) + 1), and `LinformerAttention` projects the key/value sequence dimension down to a fixed rank k.

```python
class LinearAttention(nn.Module):
    """Kernelized attention: with feature map phi(x) = elu(x) + 1, attention
    becomes phi(Q) (phi(K)^T V), which is linear in sequence length."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        batch_size = Q.size(0)
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Positive feature map phi(x) = elu(x) + 1
        Q = F.elu(Q) + 1
        K = F.elu(K) + 1

        # Contract K with V first: a (d_k x d_k) matrix, independent of n
        KV = torch.einsum("bhld,bhlm->bhdm", K, V)
        # Normalizer: 1 / (phi(Q) . sum_l phi(K_l))
        Z = 1.0 / (torch.einsum("bhld,bhd->bhl", Q, K.sum(dim=2)) + 1e-8)
        output = torch.einsum("bhld,bhdm,bhl->bhlm", Q, KV, Z)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)


class LinformerAttention(nn.Module):
    """Linformer attention: learnable projections E, F shrink the key/value
    sequence dimension from n to a fixed k, so scores are (n x k), not (n x n).
    Note: E and F act on the sequence axis, so a maximum sequence length must
    be fixed in advance."""

    def __init__(self, d_model, num_heads, max_seq_len=4096, k=128):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.k = k
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # Low-rank projections over the sequence dimension (n -> k)
        self.E = nn.Parameter(torch.randn(k, max_seq_len))
        self.F = nn.Parameter(torch.randn(k, max_seq_len))

    def forward(self, Q, K, V):
        batch_size = Q.size(0)
        seq_len = Q.size(1)
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Project along the sequence axis: (b, h, n, d_k) -> (b, h, k, d_k)
        K_proj = torch.matmul(self.E[:, :seq_len], K)
        V_proj = torch.matmul(self.F[:, :seq_len], V)

        scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V_proj)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)
```

### 2.3 Axial Attention

Axial attention factorizes 2-D attention into a row pass and a column pass. `MultiHeadAttention` here refers to a standard multi-head attention module (as implemented earlier in this series).

```python
class AxialAttention(nn.Module):
    """Axial attention for 2-D feature maps: attend along rows, then columns,
    reducing (H*W)^2 interactions to H*W*(H + W)."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.row_attn = MultiHeadAttention(d_model, num_heads)
        self.col_attn = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        batch_size, height, width, d_model = x.size()

        # Row pass: each row is an independent sequence of length `width`
        x_row = x.reshape(batch_size * height, width, d_model)
        x_row = self.row_attn(x_row, x_row, x_row)
        x_row = x_row.view(batch_size, height, width, d_model)

        # Column pass: each column is an independent sequence of length `height`
        x_col = x_row.permute(0, 2, 1, 3).contiguous().view(batch_size * width, height, d_model)
        x_col = self.col_attn(x_col, x_col, x_col)
        x_col = x_col.view(batch_size, width, height, d_model).permute(0, 2, 1, 3).contiguous()
        return x_col
```

## 3. Performance Comparison

### 3.1 Complexity Comparison

| Attention type | Time complexity | Space complexity | Practical max length |
| --- | --- | --- | --- |
| Standard attention | O(n²d) | O(n²) | ~1024 |
| Sparse attention | O(nd log n) | O(nd) | ~8192 |
| Linear attention | O(nd) | O(nd) | ~65536 |

### 3.2 Performance at Different Sequence Lengths

| Sequence length | Standard attention | Sparse attention | Linear attention |
| --- | --- | --- | --- |
| 512 | 100 ms | 80 ms | 60 ms |
| 2048 | 1600 ms | 200 ms | 100 ms |
| 8192 | OOM | 800 ms | 300 ms |

### 3.3 Quality/Efficiency Trade-offs

| Model | Accuracy drop | Speedup | Memory savings |
| --- | --- | --- | --- |
| Longformer | 1% | 4x | 8x |
| Linformer | 2% | 10x | 16x |
| Linear Attention | 3% | 20x | 32x |
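To make the linear-scaling numbers above concrete, a small sketch (sizes are illustrative assumptions) shows that `LinearAttention` never materializes an n × n score matrix:

```python
# Hypothetical shape walk-through for LinearAttention; sizes are assumptions.
import torch

d_model, num_heads, seq_len = 64, 4, 16384   # long sequence
attn = LinearAttention(d_model, num_heads)
x = torch.randn(1, seq_len, d_model)

out = attn(x, x, x)
print(out.shape)  # torch.Size([1, 16384, 64])

# The largest intermediate, KV, has shape (batch, heads, d_k, d_k) = (1, 4, 16, 16):
# its size is independent of seq_len, which is where the linear scaling comes from.
```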
## 4. Best Practices

### 4.1 Choosing an Attention Mechanism

```python
def select_attention(sequence_length, task_type):
    """Pick an attention class based on sequence length."""
    if sequence_length <= 1024:
        return MultiHeadAttention
    elif sequence_length <= 8192:
        return LongformerAttention
    else:
        return LinformerAttention


class AttentionFactory:
    @staticmethod
    def create(config):
        if config["type"] == "standard":
            return MultiHeadAttention(config["d_model"], config["num_heads"])
        elif config["type"] == "sparse":
            return LongformerAttention(config["d_model"], config["num_heads"],
                                       config["window_size"])
        elif config["type"] == "linear":
            return LinearAttention(config["d_model"], config["num_heads"])
        raise ValueError(f"unknown attention type: {config['type']}")
```

### 4.2 Long-Text Processing

```python
class LongTextProcessor:
    """Split a long input into overlapping chunks, run the model on each chunk,
    and stitch the outputs back together."""

    def __init__(self, model, chunk_size=512, overlap=128):
        self.model = model
        self.chunk_size = chunk_size
        self.overlap = overlap

    def process(self, inputs):
        # inputs: (batch, seq_len, d_model)
        seq_len = inputs.size(1)
        chunks = []
        for i in range(0, seq_len, self.chunk_size - self.overlap):
            chunks.append(inputs[:, i:i + self.chunk_size])
        outputs = [self.model(chunk) for chunk in chunks]
        return self._merge_outputs(outputs)

    def _merge_outputs(self, outputs):
        # Keep the first chunk whole; every later chunk drops its first
        # `overlap` positions, which the previous chunk already produced
        merged = [outputs[0]]
        for output in outputs[1:]:
            merged.append(output[:, self.overlap:])
        return torch.cat(merged, dim=1)
```

## 5. Summary

- Attention optimization is the key to handling long text: sparse attention lowers complexity by restricting the attention range, linear attention reaches O(nd) through kernel feature maps, and axial attention suits two-dimensional data.
- The choice is a trade-off between sequence length and accuracy. Per the comparison above, sparse attention handles sequences around 8192 tokens with roughly a 1% accuracy drop, while linear attention handles far longer sequences at a slightly larger drop.
- For long-text tasks, Longformer or Linformer are the recommended starting points; for short sequences, standard attention remains the best option.
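As a closing sketch, the factory and the chunked processor can be wired together; the configuration values and the toy wrapper below are illustrative assumptions, not part of the original setup:

```python
# Hypothetical end-to-end wiring of the pieces above; config values and the
# toy wrapper are assumptions for illustration.
import torch

config = {"type": "linear", "d_model": 64, "num_heads": 4}
attn = AttentionFactory.create(config)


class ToyEncoder(torch.nn.Module):
    """Minimal tensor-in/tensor-out wrapper so LongTextProcessor can call it."""

    def __init__(self, attn):
        super().__init__()
        self.attn = attn

    def forward(self, x):
        return self.attn(x, x, x)


processor = LongTextProcessor(ToyEncoder(attn), chunk_size=512, overlap=128)
long_input = torch.randn(1, 4096, 64)  # (batch, seq_len, d_model)
out = processor.process(long_input)
print(out.shape)  # torch.Size([1, 4096, 64])
```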