Transformer相比RNN(循环神经网络)的核心优势之一是天然支持 并行计算,这源于其自注意力机制和网络结构的设计.并行计算能力和长序列处理瓶颈是其架构特性的两个关键表现:
并行计算:指 Transformer 在训练 / 推理时通过矩阵运算并行化、模块独立性实现高效计算的能力;
长序列处理瓶颈:指当输入序列长度(n
)增加时,自注意力机制的计算 / 内存复杂度呈O(n²)
增长,导致效率骤降的问题。
并行计算
1.自注意力机制的并行性
自注意力的计算公式为:
对于序列长度为的输入,自注意力中每个位置的计算不依赖其他位置的中间结果:
- 计算的线性变换时,所有token的可同时生成(并行);计算(的分数矩阵)时,每个元素的计算独立于其他元素(可并行);即使是softmax和加权求和步骤,也可对整个序列的所有位置同时执行(并行)。
而RNN需要按序列顺序计算(依赖),完全串行,无法并行。
2. 网络结构的并行性
- 编码器 / 解码器 层的 并行:编码器的每一层(多头注意力+前馈网络)对整个序列的处理是“批量”的,所有token共享层参数,可同时更新;训练时的 并行 优化:结合数据并行(同一模型在不同样本上并行训练)、模型并行(将网络层拆分到不同设备),可充分利用GPU/TPU的并行计算能力,大幅加速训练。
核心观点:Transformer的并行能力源于模块独立性和矩阵运算的可并行性。
底层:矩阵运算天然支持并行(GPU的SIMD架构可并行处理矩阵元素);
中层:模块独立(前馈网络对每个位置的计算独立;多头注意力的“头”之间无依赖);
顶层:训练时可通过批处理(batch维度)、序列分片进一步提升并行效率。
根本原理:并行能力源于“计算单元的独立性”和“矩阵运算的可拆分性”。
前馈网络:对序列中每个位置的计算是独立函数(FFN(x_i) = W2·ReLU(W1·x_i + b1) + b2
),无跨位置依赖,可完全并行;
多头注意力:每个“头”的计算独立(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
),头之间可并行;
矩阵运算:QK^T
的每个元素(QK^T)[i][j] = Q[i]·K[j]
,元素间无依赖,可由GPU并行计算。
长序列瓶颈
长序列处理的核心瓶颈
当序列长度增大(如文档级文本、长视频帧、基因组序列,可达甚至),Transformer的性能会急剧下降,核心瓶颈来自自注意力的复杂度:
1. 计算复杂度瓶颈
自注意力的核心步骤(矩阵乘法)的计算量为(为隐藏层维度):
- 当时,计算量约为;当时,计算量增至(是前者的100倍)。
这种平方级增长会导致:
- 单次前向/反向传播时间大幅增加(训练/推理变慢);难以利用并行计算优势(过多计算量超出硬件算力上限)。
2. 内存瓶颈
自注意力过程中需要存储多个或的中间张量:
- 的形状为,总内存为;QK^的分数矩阵形状为,内存为;注意力权重矩阵(softmax结果)同样为,内存。
当时,,若每个元素为4字节(float32),仅分数矩阵就需要400MB内存,加上其他张量,单头注意力就可能占用数GB内存,远超普通GPU的显存上限(如16GB GPU难以处理的序列)。
3. 优化器的额外负担
训练时,优化器(如Adam)需要存储所有参数的梯度和动量信息,长序列会导致中间变量(如注意力权重的梯度)的内存占用也随增长,进一步加剧内存压力。
三、长序列处理的解决方案
为突破瓶颈,研究者提出了多种优化思路,核心是用“稀疏注意力”或“线性复杂度注意力”替代全局注意力:
1. 稀疏注意力(Sparse Attention)
仅计算部分位置的注意力,将复杂度降至(为局部窗口大小):
- 滑动窗口注意力(如Longformer):每个位置仅关注左右个相邻位置(总窗口),适合时序相关的长序列;固定稀疏模式(如BigBird):每个位置关注“局部窗口+随机采样+全局标记”,兼顾局部相关性和全局信息;轴向注意力(如Axial Transformer):将长序列拆分为多个维度(如文本拆分为“句-词”),在每个维度单独计算注意力,复杂度降至。
2. 线性注意力(Linear Attention)
用“核函数”替换的矩阵乘法,将复杂度降至:
- 核心思路:将改写为(为归一化项),通过核函数(如)的性质,将矩阵乘法转化为逐元素操作;代表模型:Performer(用随机特征映射近似核函数)、Linformer(用低秩矩阵近似)。
3. 分层/压缩注意力
通过“序列压缩”减少有效长度:
** hierarchical Attention**:先对长序列分块,计算块内注意力得到“块表示”,再计算块间注意力(如文档先分句子,再对句子表示计算注意力);
Downsampling:用池化(如平均池化)或卷积将长序列压缩为短序列(如ViT中的Patch Embedding将图像压缩为的patch序列)。
核心观点:长序列处理瓶颈源于自注意力的全连接关联特性,导致复杂度随长度平方增长。分层展开:
底层:自注意力需计算“每个位置与所有位置”的关联(QK^T
矩阵为n×n
);
中层:计算复杂度O(n²d)
(d
为隐藏维度)、内存占用O(n²)
(存储注意力权重);
顶层:当n
过大(如n>10k
),计算耗时、内存溢出,效率骤降。
根本原理:自注意力的“全关联定义”导致复杂度随长度平方增长,是机制固有属性。
自注意力的核心公式为:
Attention(Q,K,V) = softmax((QK^T)/√d_k)·V
其中QK^T
是n×n
矩阵(n
为序列长度),其计算/存储复杂度必然是O(n²)
;即使优化实现(如稀疏化),也只能降低系数,无法改变O(n²)
的本质(因“注意力”定义本身要求衡量位置间的关联)。