発展の経緯#
起源#
スパースアテンション(Sparse Attention)は、クエリベクトルと一組のキー・バリューペアを出力ベクトルにマッピングする最適化されたアテンションメカニズムですが、シングルヘッドアテンションやマルチヘッドアテンションとは異なり、クエリベクトルとすべてのキーの類似度を計算するのではなく、クエリベクトルと一部のキーの類似度のみを計算することで、計算量とメモリ消費を削減します。スパースアテンションの概念は、2018 年の論文「Generating Long Sequences with Sparse Transformers」に最初に登場し、この論文では、8000 語を超えるテキストシーケンスを処理するためにスパースアテンションを使用した Transformer ベースの長シーケンス生成モデルが提案されました。
実際に訓練された Transformer モデルでは、アテンションマトリックスはしばしばスパースであることが多く、これはすべてのトークンが他のすべてのトークンに注意を払う必要がないことを意味します。トークン間の相互作用の中には、最終的な出力に大きな貢献をしないものもあり、無視されることがあります。
スパースな方法は、固定パターン(例えば、ローカルウィンドウ)、コンテンツベースの選択(例えば、現在の位置に最も関連する他の位置)、または学習によって得られたパターンである可能性があります。
スパース接続の度合いを決定する基準に基づいて、方法は二つのカテゴリに分けられます:位置ベースのスパースアテンションとコンテンツベースのスパースアテンション。
位置ベースのスパースアテンション#
-
グローバルアテンション
グローバルノードの概念が設定されており、これらのグローバルノードでは、1 つのノードが他のすべてのノードに注意を払うことができます。別の理解の仕方をすれば、これらのノードはすべてのノードが相互に通信するための中継地点として機能します。スパースアテンションの本質は、各ノードが点対点で通信する必要がないということであり、P2P から P2S の概念に似ています。
-
バンドアテンション
データ分布の局所的な性質を考慮し、スライディングウィンドウの概念に似ており、1 つのノードが周囲のノードにのみ注意を払うようにし、アテンションの相互作用をローカルアテンションに制限します(アテンションが得た長距離の受容野がこのように切断されると、壊れてしまいます)。
-
ダイレイテッドアテンション
バンドアテンションの基礎の上に、より遠くの相互作用ノードを設定するためにスペースを空け、スライディングウィンドウにステップサイズの間隔を設定することに相当します(切断した部分の効果が悪いと感じて、延長しようとしていますが、依然として非常に拙劣です)。
-
ランダムアテンション
各クエリに対してランダムにいくつかのエッジをサンプリングして実現し、純粋にランダムな当選のように感じ、効果はあまり良くないでしょう。
-
ブロックアテンション
入力シーケンスをいくつかの重複しないクエリブロックに分割し、各クエリブロックにローカルメモリブロックを割り当てて、長シーケンスの効率的な処理を実現します。(この仕組みはあまり理解できませんが、これはアテンションが長さと幅を 1 つ小さくしただけではないでしょうか🤓)
コンテンツベースのスパースアテンション#
-
最大内積検索(MIPS)
コンテンツベースのスパースグラフを効率的に構築するために、最大内積検索(Maximum Inner Product Search、MIPS)問題の解決策を利用できます。MIPS の目標は、クエリと最大の点積を持つキーを見つけることであり、クエリとすべてのキー間の点積を計算する必要はありません。
NSA(推論効率と訓練の実行可能性)#
起源#
長文のモデリングは大規模モデルにとって非常に重要ですが、従来のアテンションメカニズムは平方の計算複雑度を持ち、コンテキストの長さを増やすと大量の追加計算が必要になります。長さが 2 倍になると、計算量は 3 倍になります
最先端のスパースアテンションは、KV キャッシュの排出とブロック KV キャッシュの選択、サンプリング、クラスタリングの方法に分かれています。しかし、これらの 2 つの方法は、彼らが言うほど良くはありません。
主に以下のいくつかの問題があります:局所的スパース性、アテンションへの適合性、エンドツーエンドの訓練
- 局所的スパース性:H2O のような方法は、自回帰デコード段階でのみスパースマトリックスを適用しましたが、プレフィル中には計算集約型の前処理が必要です。MInference はプレフィル時にのみスパースアテンションを採用しました。これらの方法はすべての段階でスパースアテンションを実現していないため、プレフィル主導の作業(例えば、書籍の要約、コード補完)やデコード主導の作業(例えば、思考チェーン)ではうまく機能せず、下流タスクに対してエンドツーエンドの訓練を実現するための統一されたアーキテクチャがありません。
- アテンションへの適合性:ほとんどのスパースマトリックスは MHA のスパース性を考慮していますが、MQA や GQA のような構造には適合しない場合があります。例えば、Quest メソッドでは、各アテンションヘッドには独立した kv キャッシュがありますが、MQA や GQA のように共用される場合があります。
- エンドツーエンドの訓練:現在、ほとんどのスパースマトリックスは推論タスクに特化しており、訓練タスクに特化したスパースマトリックスが必要ですが、密なマトリックスで訓練されたモデルはスパース推論でうまく機能しません。なぜなら、20%のアテンションが 70%のアテンションスコアをカバーできるからです。さらに、ClusterKV や MagicPIG のような作業は不連続な計算グラフを導入し、逆伝播が正常に実行できなくなります。不連続なメモリアクセスは、FlashAttention などの高速アテンション技術の効果的な適応を妨げ、これらの技術は連続的なメモリアクセスとブロック計算に依存して高スループットを実現します。
NSA の報告では、主に解決すべき 2 つの問題が挙げられています:
- 一つはハードウェアと連携した推論最適化で、プレフィルとデコードの 2 つの段階で理論的な最適化を実際の加速に変えるためには、メモリアクセスとハードウェアボトルネックのスケジューリングに友好的なアルゴリズムが必要です。
- 二つ目は訓練を意識したアルゴリズム設計で、エンドツーエンドの学習スパースパターンをサポートし、従来の「先に訓練し、後に剪定する」性能損失を避けることです。
解決策として提案された方法は、主に 3 つのステップに分かれています:
-
圧縮された粗粒度トークン(cmp)
連続するキー / バリューブロックをブロックレベルの表現に集約し、粗粒度の意味情報をキャッチし、計算負担を軽減します。
言い換えれば、kv の複数の次元を 1 つの次元に統合することです。例えば、1024 次元の kv を 64 次元に変えることです。
-
選択的に保持された細粒度トークン(slc)
重要なトークンを選択的に保持し、圧縮によって生じる情報損失を補います。
言い換えれば、MIPS のように最も関連性の高いトークンアテンションを探し、他のトークンは注目に値しないということです。
-
スライディングウィンドウ(win)
局所的なコンテキストを処理するためのスライディングウィンドウブランチを特別に設け、局所的なパターンが学習プロセスを支配する可能性のある問題を解決します。
言い換えれば、上記のバンドアテンションのことです(壊れてしまったが、実際には役立ちます🤯)。
デモ#
このプロセスを説明するために、簡単な例を挙げることができます:
現在の入力は であり、 と仮定します。次に、長さ 8 で を行うと仮定します。 は対称であるため、 の例を用いると、 を 8 ブロックに分けることができます 。圧縮後、 を と同じサイズのベクトルブロックに変換します。言い換えれば、多くのブロック を 1 つの に変えて、 が占めるメモリを削減し、計算を加速します。この時、元の と圧縮された を用いてアテンションスコアを計算し、圧縮アテンション を得ます。
中間部分は と呼ばれ、圧縮時に圧縮された KV ブロック を得た後、最大のいくつかのアテンションスコアを計算します。ここでは を選択し、 と仮定します。つまり、3 番目のブロックと 7 番目のブロックを選択し、選択された圧縮ブロックを復元します。つまり、 を に拡張して処理し、必要な ブロックを取得し、選択アテンション を計算します。
右側はスライディングウィンドウで、元の の中から最近の 8 つの を選択することで、スライディングウィンドウアテンション を得ることができます。
最後に、ゲート関数を用いて制御します:
節約された を分析すると、元々64 個の があり、圧縮アテンションでは 8 個の を使用し、選択アテンションでは 16 個の を使用し、スライディングウィンドウアテンションでは 8 個の を使用します。つまり、合計で 32 個の のみを使用し、半分の メモリを節約しています。
背景#
アテンション#
新しいクエリ に対して、以前のすべての t 個の ペアをクエリする必要があります。
算術強度#
:メモリへのアクセス時間は、メモリ内でアクセスされるバイト数をプロセッサのメモリ帯域幅で割ったものです。
:数学的な時間は、演算回数をプロセッサの数学帯域幅で割ったものです。
もし であれば、そのアルゴリズムは数学的制約を受けています。
上式は に置き換えることができ、左側はアルゴリズムの実装操作数とアクセスバイト数の比率であり、これをアルゴリズムの算術強度と呼びます。右側はプロセッサの数学帯域幅とメモリ帯域幅の比率であり、これをバイト比率と呼びます。
- アルゴリズムの算術強度が GPU のバイト比率を上回る場合、そのアルゴリズムは計算能力に制約されていると呼ばれ、性能は計算能力
FLOPS
に制約されます(計算能力制約 / 計算集約型演算子)。 - アルゴリズムの算術強度が GPU のバイト比率を下回る場合、そのアルゴリズムはメモリ制約を受けていると呼ばれ、性能はメモリ帯域幅に制約されます(メモリ制約 / メモリ集約型演算子)。
アルゴリズム / ネットワーク層の算術強度を GPU のバイト比率よりも高く保つことができれば、
gpu
の計算能力を十分に活用できます。
プレフィル段階では、大量の因果自己アテンションが示すバッチ行列乗算は高い算術強度を示し、性能は計算能力に制約されます。自回帰のデコード段階では、各トークンを生成するたびに以前のすべての kv キャッシュにアクセスする必要があるため、メモリ帯域幅に制約されます。このような差異は、プレフィルとトレーニングの際に計算の複雑さを減少させ、デコード段階ではメモリアクセスを減少させるという最適化方向の不一致をもたらします。
方法#
二つの方向:アルゴリズム側の設計とカーネル最適化
全体の概要#
元の を、よりコンパクトで情報密度の高い に最適化します。この変換は に基づいて動的に変化します。数式で表すと:
関数マッピングには、上記で述べた cmp、slc、win の 3 つの方法があり、どのマッピングを採用するかを制御するためにゲート因子を使用します。具体的には以下の通りです:
圧縮された粗粒度トークン(圧縮)#
ここで はブロックの長さ、 はブロック間のスライドステップであり、 は学習可能な MLP で、ブロック内のキーを圧縮キーにマッピングします。文中では、通常この は より小さい必要があり、情報の断片化を緩和します。
原文では、圧縮された表現がより粗い粒度の高次の意味情報をキャッチし、アテンションの計算負担を軽減できると述べられています。実験結果が良ければそれが全てです(🤓)。
選択的に保持された細粒度トークン(選択)#
粗粒度のトークンだけを使用することは明らかに不十分であり、大量の細粒度情報が失われます。モデルがより良く理解できるように、細粒度のブロックが必要です。
ブロックの選択:
ハードウェアに優しい考慮とアテンションスコアの固定分布に基づいています。このステップは、現代の GPU で効率的な計算を実現するために重要です。現代の GPU は、連続ブロックへのアクセスのスループットがランダムインデックスに基づく読み取りよりもはるかに優れています。また、ブロック計算は GPU のテンソルコアを最大限に活用できます。アテンションスコアは通常、空間的な連続性を示し、隣接するキーがしばしば同様の重要性レベルを持つことを示しています。これは DS が実験で発見したことです。浅い色の領域は高い注目値を示しています。
重要なアテンションスコアの計算:
すべてのアテンションスコアを計算することは明らかにコストがかかりますが、前のステップで圧縮されたアテンションを計算することでこのコストを削減できます。
しかし、上記は圧縮されたアテンションスコアに基づいています。一般的には、選択ブロックの長さを と定義します。 の場合、 となります。ブロックが不一致の場合、 が与えられると、
しかし、実際には と は異なります。NSA のスキームでは、彼らを一致させます。GQA と MQA では、異なるが同じ KV 値を共有する Q ヘッドに対して、彼らの重要なアテンションスコアは同じです。つまり、すべてのアテンションスコアを合計してこの KV のアテンションスコアとし、このステップで大量のメモリを節約できます。
最大の k 個のアテンションスコアを選択:
最大の k 個のアテンションスコアを選択します。ここで選択されるのは圧縮ブロックであり、 で、 は順位を示します。
圧縮ブロックに基づいて、最初のすべての ブロックを復元します。
スライディングウィンドウ#
アテンションメカニズムにおいて、局所的なパターンは通常より早く適応し、学習プロセスを支配する可能性があります。これにより、モデルが前の 2 つの kv から効果的に学習するのを妨げる可能性があります。この問題を解決するために、専用のスライディングウィンドウブランチが導入され、元のコンテキストを明示的に処理し、他のブランチ(圧縮と選択)がそれぞれの機能を学習することに集中できるようにします。
3 つのブランチはそれぞれ独立したキーとバリューを提供します。このアーキテクチャ設計は、局所的およびグローバル間の勾配干渉を防ぐことで安定した学習を実現し、最小限のオーバーヘッドを導入します。 の 6 つの KV 値を取得した後、ゲート制御の方法で結果を取得します。
カーネル設計#
トレーニングとプレフィル段階で FlashAttention レベルの加速を実現するために、Triton を利用してハードウェアに整合したスパースマトリックスを実現しました。
圧縮段階とスライディングウィンドウでは、FlashAttention をうまく活用して最適化できます。したがって、ここで言及されるカーネル最適化は、選択段階で生成される離散アテンションシーケンスの計算に主に焦点を当てています。
GQA と MQA の最適化に関して、FlashAttention の戦略に従い、時間的に連続したクエリブロックを SRAM にロードすると、メモリアクセスの効率が低下します。なぜなら、ブロック内のクエリが交差しない KV ブロックを必要とする可能性があるからです。この問題を解決するために、GQA 内のすべての共有する同じ kv ブロックのクエリヘッドを SRAM に一緒にロードします。
グループセンターデータのロード:
内ループ内で、 個のヘッドを持つクエリ をロードし、次に彼が属する圧縮されたインデックス を見つけます。
共有 KV:
内ループ内で、 に基づいて必要な を選択します。ここで は を満たす最小のカーネルブロックサイズです。
上記の緑のブロックは、1 つの q と一連の kv 計算を表しています。ここで、 が増加するにつれて、選択した KV ブロックが常に 3 ブロック以下であるため、NSA の加速がより顕著になります。
MoBA(大道至簡、即挿即用)#
背景#
スパース性については、アテンションスコアのスパース性だけでなく、記憶ストレージに関連する脳領域で観察されたスパース接続特性についても言及されています(ACL に投票できるかもしれません🤓)。
従来のスパースアテンションには二つの大きな欠陥があります。一つは、特定のタスクに基づいた事前定義された構造を採用しており、汎用性が非常に低いことです。もう一つは、動的にトークンを選択してスパースアテンションを訓練する方法であり、この方法は一般的に訓練段階には全く役立ちません。
NSA が解決しようとしている推論速度と訓練可能性の 2 つの問題に似て、MoBA が解決する問題も推論の加速と訓練可能性です。ブロックアテンション混合メカニズム(MoBA、Mixture of Block Attention)は、MLP からアテンションへの専門家混合(MoE)を移行します。
コアの方法は、各 に対して関連する過去の ブロックを動的に選択する機能です。
方法#
MoBA のコアメソッドは、ブロックパーティショニング(ブロック分割)と選択戦略(選択戦略)です(NSA の圧縮ブロックと選択に似ているように聞こえます🤔)。
全体#
MoBA の方法論は非常にシンプルで、 と仮定します。長さを とし、長さ の入力を 個の小さなブロックに分割します。ブロックのサイズを と定義し、この時、後のブロック選択のためのインデックスを定義します。
次に、 が 上での を計算し、最大の 個のブロックを選択します。ここでの の計算方法は、 です。外側の尖括弧は内積を示し、mean_pool
は平均値を示し、 ブロックでの平均 値を計算することに相当します。
MoBA では、自回帰言語モデルにおいて因果関係を維持することが重要であると述べています。 が将来の ブロックにルートできないことを確認する必要があります。特に特殊な状況の一つは、「現在のブロック」をクエリトークン自体を含むブロックとして定義することです。現在の ブロックへのルートも因果関係に違反する可能性があります。なぜなら、全体の ブロックの平均プールが将来の ブロックの情報を無意識に含む可能性があるからです。この問題を解決するために、各トークンがそれぞれの現在の ブロックにルートする必要があり、現在の ブロックのアテンション演算中に因果マスクを適用します。
最終的な考察#
MoBA と NSA のコアの違いは以下の通りです:
- MoBA は を分割し、より小さなブロックを選択して計算を行います。一方、NSA は圧縮後に小さなブロックを選択して計算し、さらにスライディングウィンドウを加えます。この計算のコアロジックは異なります。MoBA の選択は内積の topk によって行われ、勾配が関与する必要はありませんが、NSA の選択は勾配が戻って修正されます。
- NSA は KV ブロックの細粒度を取得しますが、MoBA は異なるクエリヘッドが異なる ブロックにアクセスできるようにします。焦点が異なり、両者は互いにできないことを行います。
著者の疑問:
- MoBA の中で小さなブロックに分割した後、アテンションスコアを計算するために FlashAttention を使用できますが、なぜ NSA の中で小さなブロックを選択した後にそれを使用できないのでしょうか?