发展历程#
起源#
稀疏注意力(Sparse Attention)是一种优化的注意力机制,它可以将一个查询向量和一组键值对映射到一个输出向量,但与单头注意力和多头注意力不同的是,它不会计算查询向量和所有键向量的相似度,而是只计算查询向量和部分键向量的相似度,从而减少计算量和内存消耗。稀疏注意力的概念最早出现在 2018 年的论文《Generating Long Sequences with Sparse Transformers》中,该论文提出了一种基于 Transformer 的长序列生成模型,其中使用了稀疏注意力来处理超过 8000 个单词的文本序列。
实际上在训练好的 Transformer 模型中,注意力矩阵往往是稀疏的,这意味着并不是每个令牌都需要关注其他所有令牌。有些令牌之间的相互作用可能对最终的输出贡献不大,可以被忽略。
稀疏的方式可以是固定的模式(如局部窗口)、基于内容的选择(如与当前位置最相关的其他位置),或者是通过学习得到的模式。
根据确定稀疏连接的度量标准,方法分为两类:基于位置的稀疏注意力和基于内容的稀疏注意力。
基于位置的稀疏注意力#
-
Global Attention
设置了全局节点的概念,在这些全局节点上,一个节点能关注其他所有节点,换种理解方式也就是:这些节点充当了所有节点相互交流的中转站,稀疏注意力的本质就是没必要点对点的每个节点都去交流,有点像是 P2P 到 P2S 的概念
-
Band Attention
考虑数据分布的局部性质,也就是类似于滑动窗口的概念,也就是一个节点只关注它周围的节点,将注意力的交互限制在局部注意力中 (Attention 好不容易有了的长距离感受野被你这么一截断就坏了)
-
Dilated Attention
在 Band Attention 的基础上,通过留白给自己设置一个更远的交互节点,相当于是在滑动窗口上给自己设置一个步长间隔 (截了一段发现效果不好又想着延长一下,不过感觉依然很拙劣)
-
Random Attention
为每个查询随机抽样一些边来实现,感觉纯随机的中奖,效果应该不咋样
-
Block Attention
将输入序列划分为若干个不重叠的查询块(query blocks),并为每个查询块分配一个局部记忆块(memory block),来实现对长序列的高效处理。(没太懂这个玩意,这不就是 Attention 把长和宽都变小了一号吗🤓)
基于内容的稀疏注意力#
-
最大内积搜索(MIPS)
为了高效地构建基于内容的稀疏图,可以利用最大内积搜索(Maximum Inner Product Search,MIPS)问题的解决方案。MIPS 的目标是找到与查询具有最大点积(dot product)的键,而不需要计算查询与所有键之间的点积
NSA (推理效率和训练可行)#
起源#
长文本建模对于大模型来说是及其重要的,但是传统的注意力机制是平方的运算复杂度,增加 context length 会增加大量的额外计算,长度增加两倍,计算量会增加三倍
最先进的稀疏注意力分成了两类,分别是 KV-cache eviction 和 block KV-cache 的选择,采样,聚类的方法。但是这两类方法都没有他们说得那么好。
主要是以下几个方面的问题: 局部稀疏,不适配 Attention,端到端训练
- 局部稀疏:H2O 这种方法只在自回归 decode 阶段应用了稀疏矩阵,但是 prefill 中需要计算密集型预处理。MInference 则只在 prefill 时采用了稀疏注意力。这些方法都没有在所有阶段实现稀疏注意力,那么对于 prefill 主导的工作比如书籍摘要,代码补全或者 decode 主导的工作比如思维链上就会表现不好,也就是针对下游任务没有一个统一的架构来实现端到端训练。
- 不适配 Attention: 大部分稀疏矩阵考虑的是对 MHA 的稀疏,而对于 MQA 和 GQA 这样的结构,会有不适配的情况,比如 Quest 方法,每一个注意力头都有它独立的 kv-cache,然而对于 MQA 和 GQA 这样共用
- 端到端训练: 现在大多数稀疏矩阵都是针对的推理任务,需要一个针对训练任务的稀疏矩阵,但是基于 dense 矩阵训练的模型在稀疏推理下表现不佳,因为 20% 的注意力只能 cover70% 的 Attention Score。更有甚者像是 ClusterKV 和 MagicPIG 这样的工作引入了不连续的计算图,从而导致反向传播无法正常执行。非连续的内存访问阻止了对 FlashAttention 等快速注意技术的有效适应,这些技术依赖于连续的内存访问和分块计算来实现高吞吐量。
NSA 的报告中提到了主要要解决的是两个问题:
- 一是与硬件联合的推理优化,在 prefill 和 decode 两个阶段,将理论优化变为实际加速需要硬件友好型的算法,主要是在内存访问和硬件瓶颈调度上
- 二是训练感知算法设计,支持端到端学习稀疏模式,避免传统方法「先训练后裁剪」的性能损失
提到了解决的方法,主要分为三步:
-
compressed coarse-grained tokens(cmp)
将连续的 key/value 块聚合为块级别的表示,捕获粗粒度的语义信息,减少计算负担。
通俗来讲就是把 kv 的多个维度融于一个维度,举个例子: 1024 维的 kv 变为 64 维度,
-
selectively retained fine-grained tokens(slc)
有选择地保留重要的 token,弥补压缩可能带来的信息损失。
通俗来讲就是类似于 MIPS 的寻找最相关的 token 注意力,其余 token 不值得关注
-
sliding windows(win)
专门处理局部上下文的滑动窗口分支,解决局部模式可能主导学习过程的问题。
通俗来讲就是上面的 Band Attention (坏了它还真有用 🤯)
Demo#
我们可以举一个简单的例子来说明一下这个过程:
我们现在有的输入是 ,假设 ,然后假设按照长度为 8 进行 ,由于 是对称的,我们用 举例,即可把 各分为 8 块 ,经过压缩后,也就是把 变为 和 一样 size 的向量块,通俗来讲就是把很多块 变为一个 来减少 所占的显存并加速计算,此时用原始的 和压缩后的 进行注意力分数计算得到压缩注意力 。
中间部分叫做 ,在压缩时我们得到了压缩后的 KV 块 ,此时计算最大的几个注意力分数,我们这里选择 的,假设为 ,也就是第三块和第七块,此时还原对应的选出来的压缩块,也就是将 扩展回到 这样去处理拿到我们所需要的 块,然后计算得到选择注意力 。
右边是滑动窗口,在原 中选择最近的 8 个 就可以得到滑动窗口注意力 。
最后再用一个门控函数控制,即
分析一下节省的 ,原本有 64 个 ,我们的压缩注意力用到了 8 个 ,选择注意力用到了 16 个 ,滑动窗口注意力用到了 8 个 ,相当于现在一共只用到了 32 个 ,节省了一半的 显存。
背景#
Attention#
对于一个新来的查询 需要查询之前所有的 t 个 对
算术强度#
:访问内存时间等于内存中访问的字节数除以处理器的内存带宽。
:数学时间等于运算次数除以处理器的数学带宽。
如果 ,那么该算法是受数学限制的
上式可以被替换为 ,左边是算法实现操作数与访问字节数的比值,被称为算法的算术强度,右边是处理器的数学带宽与内存带宽的比值,被称为字节比率
- 如果算法的算术强度高于 GPU 的 字节比率,那么该算法受算力限制的,也称
math bound
,即性能受算力FLOPS
限制(算力受限 / 计算密集型算子)。 - 如果算法的算术强度低于 GPU 的 字节比率,则该算法受内存限制,也称
memory bound
,即性能受内存带宽限制(内存受限 / 访存密集型算子)
应该尽可能让算法 / 网络层的算术强度高于 GPU 的字节比率,这样才能充分利用
gpu
的算力
在 prefill 阶段,大量的 causal self-attention 所展现出的批量矩阵乘法展现出了高算术强度,即性能受到算力限制。在自回归的 decode 阶段,因为其每生成一个 token 都需要访问之前所有的 kv-cache,则变为受到内存带宽限制。而这种差异性会带来优化方向的不一致,在 prefill 和 train 的时候减少计算复杂度,在 decode 阶段减少内存访问
方法#
两个方向:算法侧设计和 kernal 优化
整体概况#
将原始的 优化成更加 compact 和 information-dense 的 ,其中这个变换是基于 动态改变的,用公式表达就是:
对于函数映射,一共有三种方式,也就是上面所说的 cmp,slc,win 三种方式,通过一个门控因子来控制采用哪种映射,具体如下所示:
compressed coarse-grained tokens (压缩)#
其中 是块的长度, 是块间滑动的步长, 是一个可学习的 MLP,将块中的键映射到一个压缩键。文中提到通常来讲这个 是要小于 的,来缓解信息碎片化。
原文说压缩表示可以捕获更粗粒度的高级语义信息,并减轻注意力的计算负担,说是啥就是啥吧,实验出来结果好就是王 (🤓)
selectively retained fine-grained tokens (选择)#
只使用粗粒度的上述 token 势必是不行的,丢失了大量细粒度的信息,我们还需要细粒度的块来帮助模型更好地理解。
块的选择:
基于硬件友好的考虑和注意力分数的固定分布。这一步对于在现代 GPU 上实现高效的计算至关重要。现代 GPU 在连续块访问上的吞吐量远胜于基于随机索引的读取,同时块计算也可以最大效率利用 GPU 的 tensor core。注意力分数通常表现出空间连续性,这表明相邻的键往往具有相似的重要性水平,这是 DS 后面做实验发现的,浅色区域表示较高的关注值,如图所示,
重要注意力分数计算:
计算所有的注意力分数显然是一个开销很大的事,但是我们可以通过计算前一步压缩后的注意力来降低这个开销
但上面这个只是基于压缩的注意力分数,普遍意义上,我们需要的选择块长度定义为 ,当 时, ,对于分块不一致的情况,给定 ,那么
但其实 和 是不一样的,在 NSA 方案里面,让他们一致。在 GQA 和 MQA 中,对于不同的且共享同样的 KV 值的 Q head ,他们的重要注意力分数是一样的,就是把所有的注意力分数相加作为这个 KV 的注意力分数,这一步直接可以节省大量内存。
选择最大的 k 个注意力分数:
选择最大的 k 个注意力分数,注意到这里我们的选出的是压缩块,也就是 , 表示排名
基于压缩块再复原最开始的所有 $k$ 块
sliding windows (滑动窗口)#
在注意力机制中,局部模式通常适应得更快,并且可以主导学习过程,这可能会阻止模型从前两种 kv 中有效学习。为了解决这个问题,引入了一个专用的滑动窗口分支,它显式地处理原始上下文,允许其他分支(压缩和选择)专注于学习它们各自的功能。
三个分支分别提供了独立的键和值。这种架构设计通过防止局部和全局之间的梯度干扰来实现稳定的学习,同时引入最小的开销。获取 六个 KV 值后,采用 gate 门控的方式从中获取结果。
Kernel Design#
要在 train 和 prefill 阶段实现 FlashAttention 级别的加速,利用 Triton 实现了和硬件对齐的稀疏矩阵
在压缩阶段和滑动窗口都可以很好利用 FlashAttention 进行优化,因此这里提到的 kernel 优化主要是针对选择阶段所产生的离散注意力序列的计算。
针对 GQA 和 MQA 进行优化,如果我们遵循 FlashAttention 的策略,将时间连续的查询块加载到 SRAM 中,这将导致内存访问效率低下,因为块内的查询可能需要不相交的 KV 块,为了解决这个问题,将 GQA 中所有共享相同 kv 块的查询头一起加载到 SRAM 中。
组中心数据加载:
在内循环中,加载有 个头的查询 ,然后找到刚才他属于的被压缩的 index
共享 KV:
在内循环中,根据 选入此时需要的 ,其中 是满足 最小的 kernel 块大小
以上的一个绿色块代表了一个 q 和一段 kv 计算。这里我们注意到,当 t 增大时,由于我们选择的 KV 块恒为小于等于 3 块,那么越长 NSA 的加速越明显。
MoBA (大道至简,即插即用)#
背景#
对于稀疏性,不仅提到了注意力分数的稀疏性,还提到了与记忆存储相关的脑区观察到的稀疏连接特性 (感觉可以投 ACL🤓)
传统的稀疏注意力有两大缺陷,一是采用预定义的结构,基于一些特定的任务,泛化性很差,二是动态选取 Token 进行稀疏注意力训练的方式,这个方法一般对于训练阶段毫无帮助。
与 NSA 要解决的推理速度和可训练两个问题差不多,Moba 解决的问题也是加速推理以及可训练,块注意力混合机制 (MoBA,Mixture of Block Attention),将专家混合 (MoE) 从 MLP 迁移到 Attention
核心方法是了针对每个 动态选择相关历史 块的功能
方法#
MoBA 核心的方法是 block partitioning (块分区) 和 selection strategy (选择策略)(听着是不是和 NSA 的压缩块然后选择很像🤔)
总体#
MoBA 的方法论很简单, ,假设长度为 ,然后将长度为 的输入分为 个小块,其中定义块的大小为 ,此时定义一个索引,这个索引主要是为了后续选择块
然后计算 在 上的 ,选出最大的 个最大的 块,注意这里的 的计算方式, ,外层尖括号表示内积,mean_pool
表示平均值,相当于计算在 块上的平均 值。
然后 MoBA 提到了在自回归语言模型中保持因果关系很重要,需要确保 无法 Route 到任何未来的 块上。其中一个比较特殊一点的情况是,将 “current block” 定义为包含查询 token 本身的 block,到当前 块的 Route 也可能违反因果关系,因为整个 块的平均池可能会无意中包含来自未来 块的信息。为了解决这个问题,我们强制要求每个 Token 都必须 Route 到其各自的当前 块,并在当前 块 Attention 运算期间应用因果掩码。
最终思考#
MoBA 和 NSA 的核心不一样在哪里:
- MoBA 干的是将 分块,然后选择更小的块进行计算,NSA 干的是压缩后选择小块进行计算再加一个滑动窗口。这个计算的核心逻辑就不一样,MoBA 选择是通过内积的 topk 来选择的,是不需要梯度参与的,而 NSA 的选择其实是会有梯度回传来修正的。、
- NSA 干的是取 KV block 的细粒度,MoBA 干的是让不同的 query head 能接触到不同的 块,侧重点不一样且两者都不能干对方干的事
笔者问题:
- MoBA 里面的分成小块后计算 Attn 分数可以使用 FlashAttention,那为什么 NSA 里面选择完小块后不能这么使用呢?