SageAttention2 Technical Report:
Accurate 4 Bit Attention for Plug-and-play Inference Acceleration
Jintao Zhang∗
Tsinghua University
Haofeng Huang∗
Tsinghua University
Pengle Zhang
Tsinghua University
Jia Wei
Tsinghua University
Jun Zhu
Tsinghua University
Jianfei Chen
Tsinghua University
Abstract 線形層に対する量子化は広く使用されてきたが、注意機構の加速への応用は限定的であった。SageAttentionは8ビット行列乗算、16ビットアキュムレータを用いた16ビット行列乗算、および精度向上手法を活用し、FlashAttention2と比較して精度を保ちつつ2倍の高速化を実現するカーネルを実装している。注意機構の計算効率をさらに向上させつつ精度を維持するため、我々はSageAttention2 を提案する。これは大幅に高速な4ビット行列乗算(Matmul)と追加の精度向上技術を活用している。第一に、我々はワープレベルの粒度で行列( Q , K ) 𝑄 𝐾 (Q,K) ( italic_Q , italic_K ) をINT4に量子化し、行列( P ~ , V ) ~ 𝑃 𝑉 (\widetilde{P},V) ( over~ start_ARG italic_P end_ARG , italic_V ) をFP8に量子化することを提案する。第二に、Q 𝑄 Q italic_Q とV 𝑉 V italic_V を平滑化する手法を提案し、INT4 Q K 𝑄 𝐾 QK italic_Q italic_K とFP8 P V 𝑃 𝑉 PV italic_P italic_V を用いた注意機構の精度を向上させる。第三に、タイムステップと層にわたる量子化精度を分析し、様々なモデルにおけるエンドツーエンドの評価指標を保証する適応的量子化手法を提案する。SageAttention2 の1秒あたりの演算数(OPS)は、RTX4090上でFlashAttention2とxformersをそれぞれ約3倍 および5倍 上回る。包括的な実験により、我々のアプローチが大規模言語処理、画像生成、動画生成を含む多様なモデルにおいて、エンドツーエンドの評価指標の損失が無視できるほど小さいことが確認された。コードはhttps://github.com/thu-ml/SageAttention で入手可能である 。
1 Introduction
アテンションの二次計算量の複雑さにより、実世界のアプリケーションにおいて配列が長くなるにつれて、効率的な実装がますます重要になっている (Jiang et al., 2024 ) 。アテンションの計算需要を軽減するために、いくつかの戦略が開発されてきた。例えば、(1) 複雑さを O ( N ) 𝑂 𝑁 O(N) italic_O ( italic_N ) に削減する線形アテンション 手法 (Wang et al., 2020 ; Choromanski et al., 2020 ; Yu et al., 2022 ; Katharopoulos et al., 2020 ) 、(2) コンテキストの一部を選択的に処理するスパースアテンション 手法 (Liu et al., 2021 ; Chu et al., 2021 ; Li et al., ; Xiao et al., 2023b , 2024 ; Chen et al., 2023 ; Jiang et al., 2024 ; Venkataramanan et al., 2023 ; Gao et al., 2024 ; Fu et al., 2024 ) などがある。しかし、これらの手法は限られた範囲のモデルやタスクにのみ適している。広く使用されているアテンション手法は、ハードウェアの能力を活用して計算速度を向上させることでアテンションを最適化している。例えば、FlashAttention (Dao et al., 2022 ) 、FlashAttention2 (Dao, 2023 ) 、FlashAttention3 (Shah et al., 2024 ) 、xformers (Lefaudeux et al., 2022 ) 、SageAttention (Zhang et al., 2024 ) などがある。これらの研究は、コンテキストの一部に対するアテンションの計算を省略せず、様々なモデルやタスクにおいて印象的な速度と精度の性能を達成している。
注意機構における2つの行列乗算(Matmul)操作:Q K ⊤ 𝑄 superscript 𝐾 top QK^{\top} italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT とP V 𝑃 𝑉 PV italic_P italic_V について、SageAttentionはQ K ⊤ 𝑄 superscript 𝐾 top QK^{\top} italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT をINT8に量子化し、P V 𝑃 𝑉 PV italic_P italic_V にはFP16アキュムレータを用いたFP16 Matmulを使用することで高速化を実現している。さらに、注意機構の精度を維持するため、SageAttentionはK 𝐾 K italic_K のチャンネル単位の外れ値を除去することでスムージングを提案している。SageAttentionはFlashAttention2とxformersに比べて2× \times × および2.7× \times × の高速化を達成し、言語、画像、動画生成モデルにおいてエンドツーエンドの評価指標の損失が無視できるレベルである初めての量子化された注意機構となっている。しかしながら、SageAttentionには2つの弱点がある。(W1) INT8 MatmulはINT4の半分の速度しか達成できない。(W2) FP16アキュムレータを用いたFP16 MatmulはRTX4090とRTX3090 GPUにのみ対応している。
Q K ⊤ 𝑄 superscript 𝐾 top QK^{\top} italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT にはより高速なINT4テンソルコアを活用し、P V 𝑃 𝑉 PV italic_P italic_V を一般的に高速化できる手法を使用するため、我々はSageAttention2 を提案する。これはQ , K 𝑄 𝐾
Q,K italic_Q , italic_K をINT4に、P , V 𝑃 𝑉
P,V italic_P , italic_V をFP8に量子化するものである。
課題 。Q , K 𝑄 𝐾
Q,K italic_Q , italic_K をINT4に、P , V 𝑃 𝑉
P,V italic_P , italic_V をFP8に量子化することは重大な課題を提示する。例えば、Q , K 𝑄 𝐾
Q,K italic_Q , italic_K をINT4にテンソル単位で量子化するだけでも、テキストから動画を生成するモデルCogvideoXは完全にぼやけた動画を生成し(図2 参照)、Llama2はMMULUにおいてランダム推測レベルの25%の精度しか達成できない。詳細に調査した結果、我々は3つの主要な課題を特定した:(C1) INT4の数値範囲は、量子化において通常− 7 7 -7 - 7 から7 7 7 7 までの15の数字を含むが(Lin et al., 2024 ) 、これはQ 𝑄 Q italic_Q とK 𝐾 K italic_K に異常値がある場合、重大な量子化誤差につながる。(C2) 一部のモデルの特定の層とタイムステップ(テキストから画像/動画の場合)において、Q 𝑄 Q italic_Q とK 𝐾 K italic_K をINT4に、P 𝑃 P italic_P とV 𝑉 V italic_V をFP8に量子化すると、注意計算に顕著な誤差が生じる。これらの最悪のケースの層/タイムステップにおける誤差は、エンドツーエンドの出力の精度に大きな影響を与える。
(C3) 我々は、テンソルコアにおけるFP8行列乗算用に設計されたFP32アキュムレータ(mma.f32.f8.f8.f32 )が実際にはFP22、具体的には1符号ビット、8指数ビット、13仮数ビットであることを発見した。これによりP V 𝑃 𝑉 PV italic_P italic_V の精度損失が生じる。
図2: CogvideoXからQ、KをINT4に量子化した例。
我々のアプローチ 。これらの課題に対処するため、我々は理由を詳細に分析し、2つの方法を提案する。第一に、行列Q 𝑄 Q italic_Q とK 𝐾 K italic_K におけるチャンネル方向の顕著な外れ値に対して、SageAttentionにおいて平滑化K 𝐾 K italic_K を採用し、Q 𝑄 Q italic_Q におけるこれらの外れ値を除去する効果的な方法を提案する。具体的には、Q 𝑄 Q italic_Q のチャンネル次元の平均値を減算することを提案し、これをQ → m subscript → 𝑄 𝑚 \overrightarrow{Q}_{m} over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT と呼ぶ。その後、Q → m K subscript → 𝑄 𝑚 𝐾 \overrightarrow{Q}_{m}K over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT italic_K をQ K 𝑄 𝐾 QK italic_Q italic_K Matmulの後に加えることで、注意計算の正確性を確保する。第二に、特定の層とタイムステップが異なる入力間で一貫して量子化の課題を示すことを観察した。精度を維持するために、適応的な混合精度法を適用する。具体的には、これらの問題のある層とタイムステップに対しては8ビット(INT8+FP8)の注意を、その他に対しては4ビット(INT4+FP8)の注意を使用する。第三に、P V 𝑃 𝑉 PV italic_P italic_V のFP8 Matmulに22ビットのアキュムレータを使用することに関連する精度損失を軽減するために、V 𝑉 V italic_V を平滑化して精度性能を向上させる方法を提案する。
性能。 重要なことに、我々はRTX4090およびL20 GPU上でSageAttention2 の高性能実装を提供する。この実装はRTX4090上でピーク性能485 TOPS を達成し、FlashAttention2とxformersをそれぞれ約3.1倍および5.4倍上回る。我々はSageAttention2 を使用して、最先端のテキスト、画像、および動画生成モデルのエンドツーエンドメトリクスを広範に評価した。すべてのモデルとタスクにおいて、SageAttention2 はモデル性能の無視できる程度の損失で直接プラグアンドプレイ方式で採用できる。
2 Preliminary
2.1 FlashAttention
自己注意の計算は以下のように定式化できる:S = Q K ⊤ / d , P = σ ( S ) , O = P V formulae-sequence 𝑆 𝑄 superscript 𝐾 top 𝑑 formulae-sequence 𝑃 𝜎 𝑆 𝑂 𝑃 𝑉 S=QK^{\top}/\sqrt{d},~{}P=\sigma(S),~{}O=PV italic_S = italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT / square-root start_ARG italic_d end_ARG , italic_P = italic_σ ( italic_S ) , italic_O = italic_P italic_V 、ここでσ ( S ) i j = exp ( S i j ) / ∑ k exp ( S i k ) 𝜎 subscript 𝑆 𝑖 𝑗 subscript 𝑆 𝑖 𝑗 subscript 𝑘 subscript 𝑆 𝑖 𝑘 \sigma(S)_{ij}=\exp(S_{ij})/\sum_{k}\exp(S_{ik}) italic_σ ( italic_S ) start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT = roman_exp ( italic_S start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT ) / ∑ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT roman_exp ( italic_S start_POSTSUBSCRIPT italic_i italic_k end_POSTSUBSCRIPT ) はソフトマックス演算である。
行列Q 𝑄 Q italic_Q 、K 𝐾 K italic_K 、およびV 𝑉 V italic_V はそれぞれN × d 𝑁 𝑑 N\times d italic_N × italic_d の次元を持ち、一方で行列S 𝑆 S italic_S 、P 𝑃 P italic_P はN × N 𝑁 𝑁 N\times N italic_N × italic_N である。d 𝑑 d italic_d は通常小さく、例えば64や128であるが、N 𝑁 N italic_N は数千あるいは数百万にもなり得る。したがって、N × N 𝑁 𝑁 N\times N italic_N × italic_N 行列( S , P ) 𝑆 𝑃 (S,P) ( italic_S , italic_P ) は( Q , K , V ) 𝑄 𝐾 𝑉 (Q,K,V) ( italic_Q , italic_K , italic_V ) よりもはるかに大きく、素朴な実装では( S , P ) 𝑆 𝑃 (S,P) ( italic_S , italic_P ) の読み書きに膨大なグローバルメモリIOが必要となる。FlashAttention (Dao, 2023 ) は、Q 𝑄 Q italic_Q 、K 𝐾 K italic_K 、およびV 𝑉 V italic_V をトークン次元からブロックサイズb q subscript 𝑏 𝑞 b_{q} italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT 、b k v subscript 𝑏 𝑘 𝑣 b_{kv} italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPT 、b k v subscript 𝑏 𝑘 𝑣 b_{kv} italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPT のブロック{ Q i } , { K i } , { V i } subscript 𝑄 𝑖 subscript 𝐾 𝑖 subscript 𝑉 𝑖
\{Q_{i}\},\{K_{i}\},\{V_{i}\} { italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } , { italic_K start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } , { italic_V start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } にタイル化することを提案している。そして、( S , P ) 𝑆 𝑃 (S,P) ( italic_S , italic_P ) のメモリIOを避けるために、オンラインソフトマックス(Milakov & Gimelshein, 2018 ) を使用してO 𝑂 O italic_O の各ブロック、つまりO i subscript 𝑂 𝑖 O_{i} italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT を段階的に計算する:
まず、{ K i } , { V i } subscript 𝐾 𝑖 subscript 𝑉 𝑖
\{K_{i}\},\{V_{i}\} { italic_K start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } , { italic_V start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } の各ブロックに対して、以下の方程式を反復的に計算する:
S i j = Q i K j ⊤ / d , ( m i j , P ~ i j ) = σ ~ ( m i j − 1 , S i j ) , l i j = \displaystyle S^{j}_{i}=Q_{i}K_{j}^{\top}/\sqrt{d},~{}~{}(m^{j}_{i},\widetilde%
{P}^{j}_{i})=\tilde{\sigma}(m^{j-1}_{i},S^{j}_{i}),~{}~{}l_{i}^{j}= italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT / square-root start_ARG italic_d end_ARG , ( italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , over~ start_ARG italic_P end_ARG start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) = over~ start_ARG italic_σ end_ARG ( italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) , italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT =
exp ( m i j − m i j − 1 ) l i j − 1 + rowsum ( P ~ i j ) , subscript superscript 𝑚 𝑗 𝑖 subscript superscript 𝑚 𝑗 1 𝑖 superscript subscript 𝑙 𝑖 𝑗 1 rowsum subscript superscript ~ 𝑃 𝑗 𝑖 \displaystyle\exp(m^{j}_{i}-m^{j-1}_{i})l_{i}^{j-1}+\mathrm{rowsum}(\widetilde%
{P}^{j}_{i}), roman_exp ( italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT + roman_rowsum ( over~ start_ARG italic_P end_ARG start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ,
O i j = diag ( exp ( m i j − m i j − 1 ) ) O i j − 1 subscript superscript 𝑂 𝑗 𝑖 diag subscript superscript 𝑚 𝑗 𝑖 subscript superscript 𝑚 𝑗 1 𝑖 subscript superscript 𝑂 𝑗 1 𝑖 \displaystyle O^{j}_{i}=\mathrm{diag}\left(\exp(m^{j}_{i}-m^{j-1}_{i})\right)O%
^{j-1}_{i} italic_O start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_diag ( roman_exp ( italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ) italic_O start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT
+ P ~ i j V j subscript superscript ~ 𝑃 𝑗 𝑖 subscript 𝑉 𝑗 \displaystyle+\widetilde{P}^{j}_{i}V_{j} + over~ start_ARG italic_P end_ARG start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_V start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT
ここで、m i j superscript subscript 𝑚 𝑖 𝑗 m_{i}^{j} italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT とl i j superscript subscript 𝑙 𝑖 𝑗 l_{i}^{j} italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT はb q × 1 subscript 𝑏 𝑞 1 b_{q}\times 1 italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT × 1 ベクトルであり、それぞれ− ∞ -\infty - ∞ と0 0 で初期化される。σ ~ ( ) ~ 𝜎
\tilde{\sigma}() over~ start_ARG italic_σ end_ARG ( ) はオンラインソフトマックス演算子である:m i j = max { m i j − 1 , rowmax ( S i j ) } , P ~ j i = exp ( S i j − m i j ) formulae-sequence subscript superscript 𝑚 𝑗 𝑖 subscript superscript 𝑚 𝑗 1 𝑖 rowmax subscript superscript 𝑆 𝑗 𝑖 superscript subscript ~ 𝑃 𝑗 𝑖 subscript superscript 𝑆 𝑗 𝑖 subscript superscript 𝑚 𝑗 𝑖 m^{j}_{i}=\max\{m^{j-1}_{i},\mathrm{rowmax}(S^{j}_{i})\},~{}\widetilde{P}_{j}^%
{i}=\exp(S^{j}_{i}-m^{j}_{i}) italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_max { italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , roman_rowmax ( italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) } , over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_i end_POSTSUPERSCRIPT = roman_exp ( italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) 。
最後に、出力O i subscript 𝑂 𝑖 O_{i} italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT はO i = diag ( l i j ) − 1 O i j subscript 𝑂 𝑖 diag superscript superscript subscript 𝑙 𝑖 𝑗 1 superscript subscript 𝑂 𝑖 𝑗 O_{i}=\mathrm{diag}(l_{i}^{j})^{-1}O_{i}^{j} italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_diag ( italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT によって計算できる。
2.2 Quantization
行列乗算 C = A B 𝐶 𝐴 𝐵 C=AB italic_C = italic_A italic_B は、以下のように量子化によって加速できる:
( δ A , A ^ ) = ψ ( A ) , ( δ B , B ^ ) = ψ ( B ) , C ^ = A ^ B ^ , formulae-sequence subscript 𝛿 𝐴 ^ 𝐴 𝜓 𝐴 formulae-sequence subscript 𝛿 𝐵 ^ 𝐵 𝜓 𝐵 ^ 𝐶 ^ 𝐴 ^ 𝐵 \displaystyle(\delta_{A},\hat{A})=\psi(A),~{}~{}(\delta_{B},\hat{B})=\psi(B),~%
{}~{}\hat{C}=\hat{A}\hat{B}, ( italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT , over^ start_ARG italic_A end_ARG ) = italic_ψ ( italic_A ) , ( italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT , over^ start_ARG italic_B end_ARG ) = italic_ψ ( italic_B ) , over^ start_ARG italic_C end_ARG = over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG ,
C = ψ δ A δ B − 1 ( C ^ ) 𝐶 subscript superscript 𝜓 1 subscript 𝛿 𝐴 subscript 𝛿 𝐵 ^ 𝐶 \displaystyle~{}~{}C=\psi^{-1}_{\delta_{A}\delta_{B}}(\hat{C}) italic_C = italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_C end_ARG )
(1)
ψ 𝜓 \psi italic_ψ は量子化器 であり、高精度(例:FP32またはFP16)の行列 A 𝐴 A italic_A を低精度形式 A ^ ^ 𝐴 \hat{A} over^ start_ARG italic_A end_ARG (例:INT4またはFP8)にスケール δ A subscript 𝛿 𝐴 \delta_{A} italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT で変換し、ψ − 1 superscript 𝜓 1 \psi^{-1} italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT は高精度に戻す逆量子化器 である。我々は ψ δ A − 1 ( A ^ ) ≈ A subscript superscript 𝜓 1 subscript 𝛿 𝐴 ^ 𝐴 𝐴 \psi^{-1}_{\delta_{A}}(\hat{A})\approx A italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_A end_ARG ) ≈ italic_A を持つべきである。実際の行列乗算 A ^ B ^ ^ 𝐴 ^ 𝐵 \hat{A}\hat{B} over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG は低精度で行われる。現代のGPUでは、低精度の行列乗算は通常、高精度のものよりも数倍高速である。
多くの量子化器は、数値形式と粒度に依存する。例えば、共通のスケール因子を共有する要素の数などである。
例えば、INT4のテンソル単位の量子化器 は、まずテンソル全体の絶対値の最大値としてスケールを計算し、要素をINT4の最大表現可能範囲[-7, +7]にスケーリングし、その後四捨五入してINT4にキャストする:A ^ = ⌈ A / δ A ⌋ , δ A = max ( | A | ) / 7 \hat{A}=\lceil A/\delta_{A}\rfloor,\delta_{A}=\max(|{A}|)/7 over^ start_ARG italic_A end_ARG = ⌈ italic_A / italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT ⌋ , italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT = roman_max ( | italic_A | ) / 7 。
同様に、トークン単位の量子化器 はテンソルの各トークンにスケール因子を割り当てる:A ^ [ i , : ] = ⌈ A [ i , : ] / δ A [ i ] ⌋ , δ A [ i ] = max ( | A [ i , : ] | ) / 7 \hat{A}[i,:]=\lceil A[i,:]/\delta_{A}[i]\rfloor,\delta_{A}[i]=\max(|{A[i,:]}|)/7 over^ start_ARG italic_A end_ARG [ italic_i , : ] = ⌈ italic_A [ italic_i , : ] / italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] ⌋ , italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] = roman_max ( | italic_A [ italic_i , : ] | ) / 7 。
また、チャンネル単位の量子化器 はテンソルの各チャンネルにスケール因子を割り当てる。つまり、チャンネル次元に沿って:A [ : , i ] = ⌈ A [ : , i ] / δ A [ i ] ⌋ , δ A [ i ] = max ( | A [ : , i ] | ) / 7 A[:,i]=\lceil A[:,i]/\delta_{A}[i]\rfloor,\delta_{A}[i]=\max(|A[:,i]|)/{7} italic_A [ : , italic_i ] = ⌈ italic_A [ : , italic_i ] / italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] ⌋ , italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] = roman_max ( | italic_A [ : , italic_i ] | ) / 7 。
逆量子化プロセスは要素ごとのスケーリングである:ψ δ A δ B − 1 ( A ^ B ^ ) = A ^ B ^ ∗ δ A ∗ δ B superscript subscript 𝜓 subscript 𝛿 𝐴 subscript 𝛿 𝐵 1 ^ 𝐴 ^ 𝐵 ^ 𝐴 ^ 𝐵 subscript 𝛿 𝐴 subscript 𝛿 𝐵 \psi_{\delta_{A}\delta_{B}}^{-1}(\hat{A}\hat{B})=\hat{A}\hat{B}*\delta_{A}*%
\delta_{B} italic_ψ start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT ( over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG ) = over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG ∗ italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT ∗ italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT 。
2.3 SageAttention
FlashAttention2 (Dao, 2023 ) のブロックタイリングに基づき、SageAttention (Zhang et al., 2024 ) はブロック単位の粒度で Q , K 𝑄 𝐾
Q,K italic_Q , italic_K をINT8に量子化する。さらに、量子化の精度を保つために、SageAttentionは最初に K 𝐾 K italic_K を平滑化することを提案している:
K = K 𝐾 𝐾 \displaystyle K=K italic_K = italic_K
− mean ( K ) mean 𝐾 \displaystyle-\text{mean}(K) - mean ( italic_K )
Q ^ i = ⌈ Q i / δ Q ⌋ , δ Q = max ( | Q i | ) / 127 , \displaystyle\hat{Q}_{i}=\lceil Q_{i}/\delta_{Q}\rfloor,~{}~{}\delta_{Q}=\max(%
|{Q_{i}|})/127, over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = ⌈ italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT / italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ⌋ , italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT = roman_max ( | italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | ) / 127 ,
K ^ j = ⌈ K j / δ K ⌋ , δ K = max ( | K j | ) / 127 \displaystyle~{}~{}\hat{K}_{j}=\lceil K_{j}/\delta_{K}\rfloor,~{}~{}\delta_{K}%
=\max(|{K_{j}|})/127 over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT = ⌈ italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT / italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ⌋ , italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT = roman_max ( | italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT | ) / 127
S i j = subscript 𝑆 𝑖 𝑗 absent \displaystyle S_{ij}= italic_S start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT =
Q i K j ⊤ ∗ δ Q ∗ δ K subscript 𝑄 𝑖 superscript subscript 𝐾 𝑗 top subscript 𝛿 𝑄 subscript 𝛿 𝐾 \displaystyle Q_{i}K_{j}^{\top}*\delta_{Q}*\delta_{K} italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ∗ italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ∗ italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT
ここで、Q i , K j subscript 𝑄 𝑖 subscript 𝐾 𝑗
Q_{i},K_{j} italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT はFlashAttentionでタイル化されたブロックであり、mean ( K ) mean 𝐾 \text{mean}(K) mean ( italic_K ) は K 𝐾 K italic_K のチャンネル次元の平均値である。SageAttentionは P ~ V ~ 𝑃 𝑉 \widetilde{P}V over~ start_ARG italic_P end_ARG italic_V をFP16として保持し、P ~ V ~ 𝑃 𝑉 \widetilde{P}V over~ start_ARG italic_P end_ARG italic_V にはFP16アキュムレータを使用したFP16行列乗算を使用する。しかし、FP16アキュムレータを使用したFP16行列乗算は、RTX4090およびRTX3090 GPUでのみ高速化効果がある。
3 SageAttention-2
図3: SageAttention2 のワークフロー。 1 Q,K,Vを平滑化する。 2 GEMVを実行してΔ S Δ 𝑆 \Delta S roman_Δ italic_S を得る。 3 ワープごとにQ,Kを量子化し、チャンネルごとにVを量子化する。 4 SageAttention2 カーネルを実行する。 5 出力を修正する。
3.1 Formulation
第2 節で紹介したFlashAttentionと量子化に基づき、我々が開発した量子化注意機構アプローチについて説明する。
Quantization:
( δ Q , Q ^ , Q → m ) = ϕ Q ( Q ) , ( δ K , K ^ ) = ϕ K ( K ) , ( δ V , V ^ , V → m ) = ϕ V ( V ) formulae-sequence subscript 𝛿 𝑄 ^ 𝑄 subscript → 𝑄 𝑚 subscript italic-ϕ 𝑄 𝑄 formulae-sequence subscript 𝛿 𝐾 ^ 𝐾 subscript italic-ϕ 𝐾 𝐾 subscript 𝛿 𝑉 ^ 𝑉 subscript → 𝑉 𝑚 subscript italic-ϕ 𝑉 𝑉 \displaystyle{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1%
}{(\delta_{Q},\hat{Q},~{}~{}\overrightarrow{Q}_{m})}}={\color[rgb]{0,0,1}%
\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\phi_{Q}}}(Q),~{}~{}{\color[%
rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(\delta_{K},\hat{K}%
)}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\phi_{K%
}}}(K),~{}~{}{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1%
}{(\delta_{V},\hat{V},\overrightarrow{V}_{m})}}={\color[rgb]{0,0,1}%
\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\phi_{V}}}(V) ( italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT , over^ start_ARG italic_Q end_ARG , over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) = italic_ϕ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) , ( italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT , over^ start_ARG italic_K end_ARG ) = italic_ϕ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K ) , ( italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT , over^ start_ARG italic_V end_ARG , over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) = italic_ϕ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V )
Δ S = Q m → K ⊤ Δ 𝑆 → subscript 𝑄 𝑚 superscript 𝐾 top \displaystyle{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1%
}{\Delta S=\overrightarrow{Q_{m}}K^{\top}}} roman_Δ italic_S = over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT
(2)
Attention:
S = ψ δ Q δ K − 1 ( Q ^ K ^ ⊤ ) + Δ S , ( m ′ , P ~ ) = σ ~ ( m , S ) , ( δ P , P ^ ) = ψ P ( P ~ ) formulae-sequence 𝑆 subscript superscript 𝜓 1 subscript 𝛿 𝑄 subscript 𝛿 𝐾 ^ 𝑄 superscript ^ 𝐾 top Δ 𝑆 formulae-sequence superscript 𝑚 ′ ~ 𝑃 ~ 𝜎 𝑚 𝑆 subscript 𝛿 𝑃 ^ 𝑃 subscript 𝜓 𝑃 ~ 𝑃 \displaystyle S={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{%
0,0,1}{\psi^{-1}_{\delta_{Q}\delta_{K}}(\hat{Q}\hat{K}^{\top})+\Delta S}},~{}~%
{}(m^{\prime},\widetilde{P})=\tilde{\sigma}(m,S),~{}~{}{\color[rgb]{0,0,1}%
\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(\delta_{P},\hat{P})}}={\color%
[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\psi_{P}}}(%
\widetilde{P}) italic_S = italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_Q end_ARG over^ start_ARG italic_K end_ARG start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) + roman_Δ italic_S , ( italic_m start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT , over~ start_ARG italic_P end_ARG ) = over~ start_ARG italic_σ end_ARG ( italic_m , italic_S ) , ( italic_δ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT , over^ start_ARG italic_P end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG )
O = diag ( exp ( m ′ − m ) ) O + ψ δ P δ V − 1 ( P ^ V ^ ) 𝑂 diag superscript 𝑚 ′ 𝑚 𝑂 subscript superscript 𝜓 1 subscript 𝛿 𝑃 subscript 𝛿 𝑉 ^ 𝑃 ^ 𝑉 \displaystyle O=\mathrm{diag}\left(\exp(m^{\prime}-m)\right)O+{\color[rgb]{%
0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\psi^{-1}_{\delta_{P}%
\delta_{V}}(\hat{P}\hat{V})}} italic_O = roman_diag ( roman_exp ( italic_m start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT - italic_m ) ) italic_O + italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_P end_ARG over^ start_ARG italic_V end_ARG )
(3)
ϕ Q subscript italic-ϕ 𝑄 \phi_{Q} italic_ϕ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT 、ϕ K subscript italic-ϕ 𝐾 \phi_{K} italic_ϕ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT 、ϕ V subscript italic-ϕ 𝑉 \phi_{V} italic_ϕ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT は、量子化されたQ 𝑄 Q italic_Q 、K 𝐾 K italic_K 、V 𝑉 V italic_V を得るための3つの変換であり、これらについては後続のセクションで議論する。
簡略化のため、すべての上付き文字と下付き文字は省略するが、注意機構で使用される行列は依然としてタイルであり、計算は第2.1 節で説明したFlashAttentionとして組織化されている。式2 、3 に示されている元の全精度版と比較して、SageAttention2 はQ , K , P , V 𝑄 𝐾 𝑃 𝑉
Q,K,P,V italic_Q , italic_K , italic_P , italic_V に量子化器を追加し、積に逆量子化器を追加することで、Q K 𝑄 𝐾 QK italic_Q italic_K とP ~ V ~ 𝑃 𝑉 \widetilde{P}V over~ start_ARG italic_P end_ARG italic_V の両方の行列乗算を加速している。
図4: 注意機構におけるテンソルのデータ分布の典型的な例。
表1: 異なる量子化手法のエンドツーエンドメトリクスの比較。Q,Kは4ビット整数に量子化され、P,Vは全精度のままである。
Q, K
Smoothing
(Q+K)
Llama3.1
(Lambda)
↑ ↑ \uparrow ↑
Llama3.1
(WikeText)
↓ ↓ \downarrow ↓
CogVideo
(vqa-a)
↑ ↑ \uparrow ↑
CogVideo
(vqa-t)
↑ ↑ \uparrow ↑
Full-Precision
-
81.5%
6.013
77.605
75.360
INT4 Quantization
✗
72.6%
11.698
27.114
24.670
✓
80.8%
6.219
77.276
75.147
3.2 Per-warp INT4 Quantization
SageAttentionはブロック単位の量子化を使用しており、GPUのストリーミングプロセッサごとにQ 𝑄 Q italic_Q とK 𝐾 K italic_K の各ブロックを量子化する。このような量子化戦略は、トークン単位の量子化に近い精度性能を達成し、量子化スケールベクトルδ Q subscript 𝛿 𝑄 \delta_{Q} italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT とδ K subscript 𝛿 𝐾 \delta_{K} italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT のドット積のオーバーヘッドを回避できる。しかし、Q 𝑄 Q italic_Q とK 𝐾 K italic_K をINT4に量子化するには、より正確な量子化粒度が必要である。我々はワープ単位の量子化 を提案する。これはブロック単位の量子化器 よりも精密で粒度の細かい量子化アプローチであり、ベクトルのドット積による追加のオーバーヘッドもない。
具体的には、SageAttentionにおけるQ 𝑄 Q italic_Q の各ブロックはc w subscript 𝑐 𝑤 c_{w} italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT 個のセグメントに分割され、それぞれがGPUストリーミングプロセッサ(SM)内のc w subscript 𝑐 𝑤 c_{w} italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT 個のGPUワープの1つによって処理される。その後、各ワープが割り当てられたセグメントに対してMatmulを実行する。ワープ単位のINT4量子化は、Q 𝑄 Q italic_Q のb q / w c subscript 𝑏 𝑞 𝑤 𝑐 b_{q}/{wc} italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT / italic_w italic_c トークンごとにスケール因子を割り当てる:
Q ^ [ i ∗ b q c w \displaystyle\hat{Q}[\frac{i*b_{q}}{c_{w}} over^ start_ARG italic_Q end_ARG [ divide start_ARG italic_i ∗ italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG
: b q ∗ ( i + 1 ) c w , : ] = ⌈ Q [ i ∗ b q c w : b q ∗ ( i + 1 ) c w , : ] δ Q [ i ] ⌋ \displaystyle:\frac{b_{q}*(i+1)}{c_{w}},:]=\left\lceil\frac{Q[\frac{i*b_{q}}{c%
_{w}}:\frac{b_{q}*(i+1)}{c_{w}},:]}{\delta_{Q}[i]}\right\rfloor : divide start_ARG italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ∗ ( italic_i + 1 ) end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG , : ] = ⌈ divide start_ARG italic_Q [ divide start_ARG italic_i ∗ italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG : divide start_ARG italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ∗ ( italic_i + 1 ) end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG , : ] end_ARG start_ARG italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_i ] end_ARG ⌋
(4)
δ Q [ i ] = max ( | Q [ i ∗ b q c w : b q ∗ ( i + 1 ) c w , : ] | ) 7 \displaystyle\delta_{Q}[i]=\frac{\max(\left|\,Q[\frac{i*b_{q}}{c_{w}}:\frac{b_%
{q}*(i+1)}{c_{w}},:]\,\right|)}{7} italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_i ] = divide start_ARG roman_max ( | italic_Q [ divide start_ARG italic_i ∗ italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG : divide start_ARG italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ∗ ( italic_i + 1 ) end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG , : ] | ) end_ARG start_ARG 7 end_ARG
(5)
ワープ単位の量子化は、SageAttentionで使用されているブロック単位の量子化よりも優れた精度性能を提供する。これについては3.5 節で議論する。
図5: 平滑化Q 𝑄 Q italic_Q の前後におけるQ 𝑄 Q italic_Q の量子化値分布の例。
3.3 Smooth Q
INT4量子化の代表的な範囲は著しく制限されており、すなわち2 4 − 1 = 15 superscript 2 4 1 15 2^{4}-1=15 2 start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT - 1 = 15 値のみです。この制限は、注意機構においてQ , K 𝑄 𝐾
Q,K italic_Q , italic_K をINT4に量子化する際に性能を大幅に低下させます。例えば、Q , K 𝑄 𝐾
Q,K italic_Q , italic_K をINT4で量子化すると、WikiTextにおけるLlama3.1 のパープレキシティが90%以上増加し、CogvideoX が生成する動画の品質が約3倍低下します(表1 参照)。我々は実際のモデルにおけるQ , K 𝑄 𝐾
Q,K italic_Q , italic_K のデータ分布を分析しました。例えば、Llama3.1 とCogvideoX ではQ , K 𝑄 𝐾
Q,K italic_Q , italic_K においてチャンネル単位の顕著な外れ値が見られることがわかりました(図4 参照)。チャンネルごとの量子化はこのような外れ値による量子化誤差を軽減できますが、この方法はQ , K 𝑄 𝐾
Q,K italic_Q , italic_K には適用できません。なぜなら、量子化はQ K ⊤ 𝑄 superscript 𝐾 top QK^{\top} italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT の外れ値がもたらす量子化誤差を排除する方法を提案します:
Q m → = mean ( Q ) , γ ( Q ) = \displaystyle\overrightarrow{Q_{m}}=\mathrm{mean}(Q),~{}~{}\gamma(Q)= over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG = roman_mean ( italic_Q ) , italic_γ ( italic_Q ) =
Q − Q m → , γ ( K ) = K − mean ( K ) , 𝑄 → subscript 𝑄 𝑚 𝛾 𝐾
𝐾 mean 𝐾 \displaystyle Q-\overrightarrow{Q_{m}},~{}~{}\gamma(K)=K-\mathrm{mean}(K), italic_Q - over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG , italic_γ ( italic_K ) = italic_K - roman_mean ( italic_K ) ,
S = Q K = 𝑆 𝑄 𝐾 absent \displaystyle S=QK= italic_S = italic_Q italic_K =
γ ( Q ) γ ( K ) + Q m → γ ( K ) 𝛾 𝑄 𝛾 𝐾 → subscript 𝑄 𝑚 𝛾 𝐾 \displaystyle\gamma(Q)\gamma(K)+\overrightarrow{Q_{m}}\gamma(K) italic_γ ( italic_Q ) italic_γ ( italic_K ) + over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_γ ( italic_K )
ϕ Q ( Q ) = ψ Q ∘ γ ( Q ) , Q m → subscript italic-ϕ 𝑄 𝑄 subscript 𝜓 𝑄 𝛾 𝑄 → subscript 𝑄 𝑚
\displaystyle\phi_{Q}(Q)=\psi_{Q}\circ\gamma(Q),~{}\overrightarrow{Q_{m}} italic_ϕ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) = italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ∘ italic_γ ( italic_Q ) , over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG
, ϕ K ( K ) = ψ K ∘ γ ( K ) \displaystyle,~{}~{}~{}~{}~{}\phi_{K}(K)=\psi_{K}\circ\gamma(K) , italic_ϕ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K ) = italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ∘ italic_γ ( italic_K )
(6)
ここで、mean ( K ) mean 𝐾 \mathrm{mean}(K) roman_mean ( italic_K ) とγ ( K ) 𝛾 𝐾 \gamma(K) italic_γ ( italic_K ) はSageAttention (Zhang et al., 2024 ) で議論されており、K 𝐾 K italic_K の変換は注意スコアP ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG を変更しません。mean ( Q ) = 1 N ∑ t = 1 N Q [ t , : ] mean 𝑄 1 𝑁 superscript subscript 𝑡 1 𝑁 𝑄 𝑡 : \mathrm{mean}(Q)=\frac{1}{N}\sum_{t=1}^{N}Q[t,:] roman_mean ( italic_Q ) = divide start_ARG 1 end_ARG start_ARG italic_N end_ARG ∑ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT italic_Q [ italic_t , : ] は形状1 × d 1 𝑑 1\times d 1 × italic_d のベクトルです。Q 𝑄 Q italic_Q の変換については、Q m → → subscript 𝑄 𝑚 \overrightarrow{Q_{m}} over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG とK ⊤ superscript 𝐾 top K^{\top} italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT の間でGEMV(一般行列ベクトル乗算)を実行します。つまり、Q m → K ⊤ → subscript 𝑄 𝑚 superscript 𝐾 top \overrightarrow{Q_{m}}K^{\top} over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT をΔ S Δ 𝑆 \Delta S roman_Δ italic_S として計算します。このΔ S Δ 𝑆 \Delta S roman_Δ italic_S は注意の計算中にS 𝑆 S italic_S に加算されます。
最終的に、全精度のQ , K 𝑄 𝐾
Q,K italic_Q , italic_K から量子化されたQ ^ , K ^ ^ 𝑄 ^ 𝐾
\hat{Q},\hat{K} over^ start_ARG italic_Q end_ARG , over^ start_ARG italic_K end_ARG への変換は式6 のように表現できます。ここで、ψ Q subscript 𝜓 𝑄 \psi_{Q} italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT 、ψ K subscript 𝜓 𝐾 \psi_{K} italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT はQ 𝑄 Q italic_Q とK 𝐾 K italic_K のための2つの量子化器です。言い換えれば、全精度のQ , K 𝑄 𝐾
Q,K italic_Q , italic_K はチャンネル次元の平均値を減算してから量子化されます。
図5 は、CogvideoX から平滑化Qの有無による量子化されたQ 𝑄 Q italic_Q の分布の例を示しています。平滑化Qを使用することで、INT4の範囲がより均一かつ完全に利用されていることがわかります。
表1 は、Llama3.1 (Dubey et al., 2024 ) とCogvideoX (Yang et al., 2024 ) における平滑化Q+K の有無による異なる量子化方法のエンドツーエンドの指標を示しています。結果は平滑化Q+K が精度に大きな利点をもたらすことを示しています。また、表10 と表10 は、効果の順序が平滑化Q+K > > > 平滑化Q > > > 平滑化K > > > その他のベースライン であることを示しています。
アルゴリズム1 SageAttention2 の実装。
入力: 行列
Q ( FP16 ) , K ( FP16 ) , V ( FP16 ) ∈ ℝ N × d 𝑄 FP16 𝐾 FP16 𝑉 FP16
superscript ℝ 𝑁 𝑑 Q(\text{FP16}),K(\text{FP16}),V(\text{FP16})\in\mathbb{R}^{N\times d} italic_Q ( FP16 ) , italic_K ( FP16 ) , italic_V ( FP16 ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_N × italic_d end_POSTSUPERSCRIPT 、ブロックサイズ
b q , b k v subscript 𝑏 𝑞 subscript 𝑏 𝑘 𝑣
b_{q},b_{kv} italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT , italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPT 、ワープ数
c w subscript 𝑐 𝑤 c_{w} italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT 。
前処理: K = K − mean ( K ) 𝐾 𝐾 mean 𝐾 K=K-\mathrm{mean}(K) italic_K = italic_K - roman_mean ( italic_K ) , Q → m = mean ( Q ) subscript → 𝑄 𝑚 mean 𝑄 \overrightarrow{Q}_{m}=\mathrm{mean}(Q) over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = roman_mean ( italic_Q ) , Q = Q − Q → m 𝑄 𝑄 subscript → 𝑄 𝑚 Q=Q-\overrightarrow{Q}_{m} italic_Q = italic_Q - over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT , Δ S = GEMV ( Q → m , K ⊤ ) Δ 𝑆 GEMV subscript → 𝑄 𝑚 superscript 𝐾 top \Delta S=\text{GEMV}(\overrightarrow{Q}_{m},K^{\top}) roman_Δ italic_S = GEMV ( over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT , italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) , V → m = mean ( V ) , V = V − V → m formulae-sequence subscript → 𝑉 𝑚 mean 𝑉 𝑉 𝑉 subscript → 𝑉 𝑚 \overrightarrow{V}_{m}=\mathrm{mean}(V),V=V-\overrightarrow{V}_{m} over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = roman_mean ( italic_V ) , italic_V = italic_V - over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT 。
量子化: ( δ Q , Q ^ ) = ψ Q ( Q ) subscript 𝛿 𝑄 ^ 𝑄 subscript 𝜓 𝑄 𝑄 {\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(\delta_{Q}%
,\hat{Q})}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}%
{\psi_{Q}}}(Q) ( italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT , over^ start_ARG italic_Q end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) // INT4 ワープごと。 ( δ K , K ^ ) = ψ K ( K ) subscript 𝛿 𝐾 ^ 𝐾 subscript 𝜓 𝐾 𝐾 ~{}~{}{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(%
\delta_{K},\hat{K})}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{%
rgb}{0,0,1}{\psi_{K}}}(K) ( italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT , over^ start_ARG italic_K end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K ) 。
( δ V , V ^ ) = ψ V ( V ) subscript 𝛿 𝑉 ^ 𝑉 subscript 𝜓 𝑉 𝑉 ~{}~{}{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(%
\delta_{V},\hat{V})}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{%
rgb}{0,0,1}{\psi_{V}}}(V) ( italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT , over^ start_ARG italic_V end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V ) 。
// FP8 チャンネルごと。
Q ^ ^ 𝑄 \hat{Q} over^ start_ARG italic_Q end_ARG を
T m = N / b q subscript 𝑇 𝑚 𝑁 subscript 𝑏 𝑞 T_{m}={N}/{b_{q}} italic_T start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = italic_N / italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ブロック
{ Q ^ i } subscript ^ 𝑄 𝑖 \{\hat{Q}_{i}\} { over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } に分割する;
K ^ ^ 𝐾 \hat{K} over^ start_ARG italic_K end_ARG ,
V ^ ^ 𝑉 \hat{V} over^ start_ARG italic_V end_ARG , および
Δ S Δ 𝑆 \Delta S roman_Δ italic_S を
T n = N / b k v subscript 𝑇 𝑛 𝑁 subscript 𝑏 𝑘 𝑣 T_{n}={N}/{b_{kv}} italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT = italic_N / italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPT ブロック
{ K ^ i } subscript ^ 𝐾 𝑖 \{\hat{K}_{i}\} { over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } ,
{ V ^ i } subscript ^ 𝑉 𝑖 \{\hat{V}_{i}\} { over^ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } , および
{ Δ S i } Δ subscript 𝑆 𝑖 \{\Delta S_{i}\} { roman_Δ italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } に分割する;
for i = 1 𝑖 1 i=1 italic_i = 1 to T m subscript 𝑇 𝑚 T_{m} italic_T start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT do
Q ^ i subscript ^ 𝑄 𝑖 \hat{Q}_{i} over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT と
δ Q [ i ∗ c w : ( i + 1 ) ∗ c w ] \delta_{Q}[i*c_{w}:(i+1)*c_{w}] italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_i ∗ italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT : ( italic_i + 1 ) ∗ italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ] をSMにロードする;
for j in [1,
T n subscript 𝑇 𝑛 T_{n} italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ]
do
K ^ j subscript ^ 𝐾 𝑗 \hat{K}_{j} over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ,
V ^ j subscript ^ 𝑉 𝑗 \hat{V}_{j} over^ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT , および
δ K [ j ] subscript 𝛿 𝐾 delimited-[] 𝑗 \delta_{K}[j] italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT [ italic_j ] をSMにロードする;
w = range ( c w ) , s t = w ∗ c w formulae-sequence 𝑤 range subscript 𝑐 𝑤 𝑠 𝑡 𝑤 subscript 𝑐 𝑤 w=\text{range}(c_{w}),st=w*c_{w} italic_w = range ( italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ) , italic_s italic_t = italic_w ∗ italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT
S i j [ s t : s t + c w ] = Matmul ( Q ^ i [ s t : s t + c w ] , K ^ j ⊤ ) × δ Q [ s t + w ] × δ K [ j ] S_{i}^{j}[st:st+c_{w}]=\mathrm{Matmul}(\hat{Q}_{i}[st:st+c_{w}],\hat{K}_{j}^{%
\top})\times\delta_{Q}[st+w]\times\delta_{K}[j] italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT [ italic_s italic_t : italic_s italic_t + italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ] = roman_Matmul ( over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT [ italic_s italic_t : italic_s italic_t + italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ] , over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) × italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_s italic_t + italic_w ] × italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT [ italic_j ] + Δ S j Δ subscript 𝑆 𝑗 \Delta S_{j} roman_Δ italic_S start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ; // c w subscript 𝑐 𝑤 c_{w} italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ワープで並列化。
m i j = max ( m i j − 1 , rowmax ( S i j ) ) superscript subscript 𝑚 𝑖 𝑗 max superscript subscript 𝑚 𝑖 𝑗 1 rowmax superscript subscript 𝑆 𝑖 𝑗 m_{i}^{j}=\mathrm{max}(m_{i}^{j-1},\mathrm{rowmax}(S_{i}^{j})) italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = roman_max ( italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT , roman_rowmax ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ) ) ,
P ~ i j = exp ( S i j − m i j ) superscript subscript ~ 𝑃 𝑖 𝑗 exp superscript subscript 𝑆 𝑖 𝑗 superscript subscript 𝑚 𝑖 𝑗 \widetilde{P}_{i}^{j}=\mathrm{exp}(S_{i}^{j}-m_{i}^{j}) over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = roman_exp ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT - italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ) ,
l i j = e m i j − 1 − m i j + rowsum ( P ~ i j ) superscript subscript 𝑙 𝑖 𝑗 superscript 𝑒 superscript subscript 𝑚 𝑖 𝑗 1 superscript subscript 𝑚 𝑖 𝑗 rowsum superscript subscript ~ 𝑃 𝑖 𝑗 l_{i}^{j}=e^{m_{i}^{j-1}-m_{i}^{j}}+\mathrm{rowsum}(\widetilde{P}_{i}^{j}) italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = italic_e start_POSTSUPERSCRIPT italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT - italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT + roman_rowsum ( over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ) ;
O i j = diag ( e m i j − 1 − m i j ) − 1 O i j − 1 + superscript subscript 𝑂 𝑖 𝑗 limit-from diag superscript superscript 𝑒 superscript subscript 𝑚 𝑖 𝑗 1 superscript subscript 𝑚 𝑖 𝑗 1 superscript subscript 𝑂 𝑖 𝑗 1 O_{i}^{j}=\mathrm{diag}(e^{m_{i}^{j-1}-m_{i}^{j}})^{-1}O_{i}^{j-1}+ italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = roman_diag ( italic_e start_POSTSUPERSCRIPT italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT - italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT + Matmul ( ( P ~ i j ∗ 448 ) .to(FP8.e4m3) , V j ) Matmul superscript subscript ~ 𝑃 𝑖 𝑗 448 .to(FP8.e4m3) subscript 𝑉 𝑗 \mathrm{Matmul}((\widetilde{P}_{i}^{j}*448)\text{.to(FP8.e4m3)},V_{j}) roman_Matmul ( ( over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ∗ 448 ) .to(FP8.e4m3) , italic_V start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) ;
end for
δ V subscript 𝛿 𝑉 \delta_{V} italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT をSMにロードする;
O i = diag ( l i T n ) − 1 O i T n / 448 ∗ δ V subscript 𝑂 𝑖 diag superscript superscript subscript 𝑙 𝑖 subscript 𝑇 𝑛 1 superscript subscript 𝑂 𝑖 subscript 𝑇 𝑛 448 subscript 𝛿 𝑉 O_{i}=\mathrm{diag}(l_{i}^{T_{n}})^{-1}O_{i}^{T_{n}}/448*\delta_{V} italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_diag ( italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT end_POSTSUPERSCRIPT / 448 ∗ italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ;
O i subscript 𝑂 𝑖 O_{i} italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT を書き込む ;
end for
O = { O i } 𝑂 subscript 𝑂 𝑖 O=\{O_{i}\} italic_O = { italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } ,
O = O + V → m 𝑂 𝑂 subscript → 𝑉 𝑚 O=O+\overrightarrow{V}_{m} italic_O = italic_O + over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT 。
表6: CogvideoX モデルの実際のテンソルにおけるV 𝑉 V italic_V の平滑化の有無による精度の例。
Smooth V
Cos Sim ↑ ↑ \uparrow ↑
Relative L1 ↓ ↓ \downarrow ↓
RMSE ↓ ↓ \downarrow ↓
✗
98.25%
0.1980
0.2387
✓
99.75%
0.0406
0.0773
図6: FP22データ型で表現されたP ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG の1行とV 𝑉 V italic_V の1列の内積精度の例。
表7: mma(f8f8f32) のFP8行列乗算命令の誤差。
Precision of Accumulated Value
E8M13
E8M23
Error compared to FP32
0
| | | | mma(f32.f16.f16.f32) - mma(f32.f8.f8.f32)| | | |
3.4 Smooth V
SageAttentionの欠点である、P V 𝑃 𝑉 PV italic_P italic_V に対するFP16累算器を持つFP16 MatmulがRTX4090のようなGPUでのみ効果的であることを避けるため、我々はP 𝑃 P italic_P とV 𝑉 V italic_V をFP8に量子化してFP8テンソルコアの普遍的な加速を活用するアプローチを採用した。しかし、我々はAda アーキテクチャ上のmma(f32f8f8f32) 命令の累算器が実際にはFP22、具体的には1符号ビット、8指数ビット、13仮数ビットであることを発見した。具体的には、mma(f32f8f8f32) 命令C = A B + D 𝐶 𝐴 𝐵 𝐷 C=AB+D italic_C = italic_A italic_B + italic_D において、A , B 𝐴 𝐵
A,B italic_A , italic_B はFP8データ型のテンソルで、C , D 𝐶 𝐷
C,D italic_C , italic_D はFP32データ型のテンソルである。我々はA , B 𝐴 𝐵
A,B italic_A , italic_B をゼロに初期化し、D 𝐷 D italic_D を変化させて累算器のデータ型をテストした。表7 に示すように、D 𝐷 D italic_D が1符号ビット、8指数ビット、13仮数ビットで初期化されると、C 𝐶 C italic_C の値はmma(f16f16f32) 命令の結果と一致する。しかし、D 𝐷 D italic_D が13ビット以上の仮数ビットで初期化されると、C 𝐶 C italic_C の誤差はmma(f32f16f16f32) とmma(f32f8f8f32) の結果の差に対応する。
したがって、FP8に量子化された行列P ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG とV 𝑉 V italic_V の行列乗算は、FP32累算器を使用する場合と比較して、ある程度の精度損失を被る。この精度損失を可能な限り軽減するために、我々はV 𝑉 V italic_V を滑らかにすることを提案する:
γ ( V ) = V − mean ( V ) , 𝛾 𝑉 𝑉 mean 𝑉 \displaystyle\gamma(V)=V-\mathrm{mean}(V), italic_γ ( italic_V ) = italic_V - roman_mean ( italic_V ) ,
mean ( V ) = V m → mean 𝑉 → subscript 𝑉 𝑚 \displaystyle~{}~{}~{}~{}\mathrm{mean}(V)=\overrightarrow{V_{m}} roman_mean ( italic_V ) = over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG
ϕ V ( V ) = subscript italic-ϕ 𝑉 𝑉 absent \displaystyle\phi_{V}(V)= italic_ϕ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V ) =
ψ V ∘ γ , V m → subscript 𝜓 𝑉 𝛾 → subscript 𝑉 𝑚
\displaystyle\psi_{V}\circ\gamma,~{}\overrightarrow{V_{m}} italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ∘ italic_γ , over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG
(7)
図6 に示すように、この戦略は以下の理由によりP ~ V ~ 𝑃 𝑉 \widetilde{P}V over~ start_ARG italic_P end_ARG italic_V の値に対するFP22の精度を向上させる:P ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG の各行は0から1の値範囲にわたり、V 𝑉 V italic_V の各列は一貫してチャンネル方向の外れ値を持ち、それらは排他的に正または負で、通常8から9の範囲にある。結果として、P ~ V ~ 𝑃 𝑉 \widetilde{P}V over~ start_ARG italic_P end_ARG italic_V の値はかなり大きくなる可能性がある。しかし、浮動小数点数の表現範囲は一様ではなく、ゼロ付近でより密である。したがって、チャンネル次元に沿って平均をV 𝑉 V italic_V から減算することで、P ~ V ~ 𝑃 𝑉 \widetilde{P}V over~ start_ARG italic_P end_ARG italic_V の値はゼロに近くなり、より高い表現精度が得られる。我々は、このような説明と戦略がコミュニティによって報告された多くの問題を解決できると考えている。
さらに、アテンション計算の正確性を維持するためには、V m → → subscript 𝑉 𝑚 \overrightarrow{V_{m}} over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG をO 𝑂 O italic_O の最終計算に加えるだけで十分である:O = O + V m → 𝑂 𝑂 → subscript 𝑉 𝑚 O=O+\overrightarrow{V_{m}} italic_O = italic_O + over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG 。これは、P ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG 行列の各行の合計が1に等しいため、P ~ V m → = V m → ~ 𝑃 → subscript 𝑉 𝑚 → subscript 𝑉 𝑚 \widetilde{P}\overrightarrow{V_{m}}=\overrightarrow{V_{m}} over~ start_ARG italic_P end_ARG over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG = over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG となるからである。
言い換えれば、この方法はV 𝑉 V italic_V を2つの部分に分解する:V m → → subscript 𝑉 𝑚 \overrightarrow{V_{m}} over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG とV 𝑉 V italic_V 。V 𝑉 V italic_V については、各列の値をゼロの周りに中心化し、これにより量子化されたP ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG 行列の行と量子化されたV 𝑉 V italic_V 行列の列との内積結果がゼロに近くなる。これによりFP22の表現がより正確になる。一方、V m → → subscript 𝑉 𝑚 \overrightarrow{V_{m}} over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG はFP16で保持され、最後にO 𝑂 O italic_O に加算されるため、V m → → subscript 𝑉 𝑚 \overrightarrow{V_{m}} over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG 部分の精度損失を引き起こさない。
表6 は、CogvideoX からサンプリングされた実際のテンソルに対する、V 𝑉 V italic_V の平滑化の有無によるアテンション精度を示している。これは、Q , K 𝑄 𝐾
Q,K italic_Q , italic_K をINT4に、P ~ , V ~ 𝑃 𝑉
\widetilde{P},V over~ start_ARG italic_P end_ARG , italic_V をFP8に量子化する際に、Vを平滑化することでSageAttention2 の精度を向上させることができることを示している。
3.5 Quantization for Q, K, P, V
Q , K 𝑄 𝐾
Q,K italic_Q , italic_K の量子化。 我々は、ワープ単位でのψ Q ( Q ) subscript 𝜓 𝑄 𝑄 \psi_{Q}(Q) italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) とψ K ( K ) subscript 𝜓 𝐾 𝐾 \psi_{K}(K) italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K ) を提案する。これは第一に、Q K ⊤ 𝑄 superscript 𝐾 top QK^{\top} italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT の内部軸のスケール係数を非量子化に使用できないため、チャンネル単位の量子化が実現不可能であるためである(Xiao et al., 2023a ) 。第二に、表3 と表3 に示すように、我々はCogvideoX のすべての層にわたる実際のQ , K , V 𝑄 𝐾 𝑉
Q,K,V italic_Q , italic_K , italic_V を使用して、トークン単位、ワープ単位、ブロック単位、テンソル単位の粒度でのINT4量子化の平均精度と最悪ケースの精度を比較した。結果は、ワープ単位の量子化の精度がトークン単位に非常に近く、ブロック単位やテンソル単位よりもはるかに優れていることを示している。さらに、セクション3.2 で議論したように、ワープ単位の量子化はトークン単位よりも非量子化のオーバーヘッドが少ない。
P ~ , V ~ 𝑃 𝑉
\widetilde{P},V over~ start_ARG italic_P end_ARG , italic_V の量子化。 我々はψ P ( P ~ ) subscript 𝜓 𝑃 ~ 𝑃 \psi_{P}(\widetilde{P}) italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG ) とψ V ( V ) subscript 𝜓 𝑉 𝑉 \psi_{V}(V) italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V ) にFP8、具体的にはE4M3データ型を選択した。これには二つの理由がある。第一に、ほとんどのGPUがFP8行列乗算演算をサポートするテンソルコアを持っており、これはFP16を使用する場合の2倍の速度である。第二に、表5 と表5 は、CogvideoX のすべての層にわたる実際のQ , K , V 𝑄 𝐾 𝑉
Q,K,V italic_Q , italic_K , italic_V を使用して、P ~ , V ~ 𝑃 𝑉
\widetilde{P},V over~ start_ARG italic_P end_ARG , italic_V に使用される異なるデータ型の平均精度と最悪精度を示している。E4M3を使用する精度がFP16を使用する場合に非常に近く、E5M2やINT8よりも優れていることがわかる。
我々は、ψ P ( P ~ ) subscript 𝜓 𝑃 ~ 𝑃 \psi_{P}(\widetilde{P}) italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG ) をブロック単位で、ψ V ( V ) subscript 𝜓 𝑉 𝑉 \psi_{V}(V) italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V ) をチャンネル単位で使用することを提案する。これには三つの理由がある。第一に、P ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG のチャンネル単位量子化とV 𝑉 V italic_V のトークン単位量子化は、非量子化に外部軸のスケール係数が必要なため実現不可能である。第二に、P ~ = exp ( S i − rowmax ( S i ) ) ~ 𝑃 exp subscript 𝑆 𝑖 rowmax subscript 𝑆 𝑖 \widetilde{P}=\mathrm{exp}(S_{i}-\mathrm{rowmax}(S_{i})) over~ start_ARG italic_P end_ARG = roman_exp ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - roman_rowmax ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ) であり、ここでS i subscript 𝑆 𝑖 S_{i} italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT はQ 𝑄 Q italic_Q とK ⊤ superscript 𝐾 top K^{\top} italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT のブロックの行列乗算結果であり、P ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG の各行の最大値は1である。したがって、我々はブロックP ~ ~ 𝑃 \widetilde{P} over~ start_ARG italic_P end_ARG に単一の静的スケールs = 1 448 𝑠 1 448 s=\frac{1}{448} italic_s = divide start_ARG 1 end_ARG start_ARG 448 end_ARG を割り当てることができ、その精度はトークン単位の量子化と等しい 。第三に、チャンネル単位の量子化はV 𝑉 V italic_V のチャンネル単位の外れ値に対処できる。
精度指標。 我々は、量子化された注意出力O ′ superscript 𝑂 ′ O^{\prime} italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT を全精度の注意出力O 𝑂 O italic_O と比較するために3つの指標を使用する:まず、O ′ superscript 𝑂 ′ O^{\prime} italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT とO 𝑂 O italic_O を1 × n 1 𝑛 1\times n 1 × italic_n の形状のベクトルに平坦化する。次に、コサイン類似度:C o s S i m = ∑ O O ′ / ∑ O 2 ∑ O ′ 2 𝐶 𝑜 𝑠 𝑆 𝑖 𝑚 𝑂 superscript 𝑂 ′ superscript 𝑂 2 superscript 𝑂 ′ 2
CosSim=\sum OO^{\prime}/\smash{\sqrt{\sum O^{2}}}\smash{\sqrt{\sum O^{\prime 2%
}}} italic_C italic_o italic_s italic_S italic_i italic_m = ∑ italic_O italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT / square-root start_ARG ∑ italic_O start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG square-root start_ARG ∑ italic_O start_POSTSUPERSCRIPT ′ 2 end_POSTSUPERSCRIPT end_ARG 、相対L1距離:L 1 = ∑ | O − O ′ | / ∑ | O | 𝐿 1 𝑂 superscript 𝑂 ′ 𝑂 L1=\sum|O-O^{\prime}|/\sum|O| italic_L 1 = ∑ | italic_O - italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT | / ∑ | italic_O | 、二乗平均平方根誤差:R M S E = ( 1 / n ) ∑ ( O − O ′ ) 2 𝑅 𝑀 𝑆 𝐸 1 𝑛 superscript 𝑂 superscript 𝑂 ′ 2 RMSE=\sqrt{(1/n)\sum(O-O^{\prime})^{2}} italic_R italic_M italic_S italic_E = square-root start_ARG ( 1 / italic_n ) ∑ ( italic_O - italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG 。
表8: SageAttention2 の2つのカーネル実装。
Kernel
ψ Q ( Q ) subscript 𝜓 𝑄 𝑄 \psi_{Q}(Q) italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) , ψ K ( K ) subscript 𝜓 𝐾 𝐾 \psi_{K}(K) italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K )
ψ P ( P ~ ) subscript 𝜓 𝑃 ~ 𝑃 \psi_{P}(\widetilde{P}) italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG ) , ψ V ( V ) subscript 𝜓 𝑉 𝑉 \psi_{V}(V) italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V )
SageAttn2-4b
INT4, per-warp
FP8, per-block; FP8, per-channel
SageAttn2-8b
INT8, per-warp
FP8, per-block; FP8, per-channel
図7: Llama3.1 とCogvideoX の異なる入力に対する、異なる層とタイムステップにおけるSageAttn-4b のc o s s i m ∗ ( 1 − L 1 ) 𝑐 𝑜 𝑠 𝑠 𝑖 𝑚 1 𝐿 1 cossim*(1-L1) italic_c italic_o italic_s italic_s italic_i italic_m ∗ ( 1 - italic_L 1 ) の平均と標準偏差。
4 Experiment
4.1 Setup
モデル。
我々は、言語、画像、動画生成の代表的なモデルの多様なセットにわたってSageAttention2 の有効性を検証する。
具体的には、5つのモデルで実験を行う:テキスト生成にはLlama2 (7B) (Touvron et al., 2023 ) 、Llama3.1 (8B) (Dubey et al., 2024 ) 、GLM4 (9B) (GLM et al., 2024 ) を、テキストから動画生成にはCogvideoX (2B) (Yang et al., 2024 ) とOpen-Sora (Zheng et al., 2024 ) を、テキストから画像生成にはFlux (schnell) (Black Forest Labs, 2023 ) を、画像分類にはTIMM (Wightman, 2019 ) を使用する。
データセット。
テキスト生成モデルは4つのゼロショットタスクで評価される:WikiText (Merity et al., 2022 ) でモデルの予測信頼度を評価し、LAMBADA (Paperno et al., 2016 ) で文脈理解を評価し、MMLU (Hendrycks et al., 2020 ) で様々な分野の知識を測定し、Longbench (Bai et al., 2023 ) で二言語、マルチタスク、長文脈理解能力の包括的評価を行う。
テキストから動画生成モデルはopen-sora (Zheng et al., 2024 ) のプロンプトセットを用いて評価される。
Flux はMJHQ-30K (Li et al., 2024 ) で評価される。
TIMM は3つの画像データセットで評価される:ImageNet (Deng et al., 2009 ) 、ImageNet-Sketch (Sketch) (Wang et al., 2019 ) 、ImageNet-Rendition (ImageNet-r) (Hendrycks et al., 2021 ) である。
評価指標。
テキスト生成モデルについては、WikiTextにはパープレキシティ (ppl.) (Jelinek et al., 1977 ) を、LAMBADAとMMLUには正確性 (Acc.)を、Longbenchにはスコア (Bai et al., 2023 ) を使用する。
テキストから動画生成モデルについては、Zhao et al. (2024 ) に従い、生成された動画の品質を5つの指標で評価する:テキストと動画の整合性を測るCLIPSIMとCLIP-Temp (CLIP-T) (Liu et al., 2024 ) 、動画の美的品質と技術的品質をそれぞれ評価する(VQA-a)と(VQA-t)、時間的一貫性を評価するFlow-score (FScore) (Wu et al., 2023 ) である。
テキストから画像生成モデルについては、生成された画像をMJHQ-30Kデータセットの画像と3つの側面で比較する:忠実度評価にはFID (Heusel et al., 2017 ) とsFID (Salimans et al., 2016 ) を、テキストと画像の整合性にはClipscore (CLIP) (Hessel et al., 2021 ) を、人間の選好にはImageReward (IR) (Xu et al., 2024 ) を使用する。
TIMM については、正確性を使用する。
実装の詳細。 我々はCUDAを用いてアテンションカーネルを実装し、RTX4090とL20 GPUを搭載したUbuntu 22.04サーバーで実験を行う。
ベースライン。
(1) SmoothAttn 。Qserve (Lin et al., 2024 ) に従い、行列Q , K 𝑄 𝐾
Q,K italic_Q , italic_K に対して平滑化因子α = 0.5 𝛼 0.5 \alpha=0.5 italic_α = 0.5 でスムースクォンタイゼーションを適用する。
(2) HadmdAttn 。Quarot (Ashkboos et al., 2024 ) に従い、INT4量子化の前に行列Q , K 𝑄 𝐾
Q,K italic_Q , italic_K にランダムアダマール変換を適用する。
SageAttn-mix のハイパーパラメータ。 我々はξ 𝜉 \xi italic_ξ として、Llama2 に50%、CogvideoX に60%、Open-Sora に25%、Llama3.1 とGLM4 に30%、Flux とTIMM に0%を使用する。
図14: CogvideoX からの比較例、プロンプトはopen-soraプロンプトセットからサンプリングされている。
表11: テキスト、画像、動画生成モデルにおけるエンドツーエンドの指標損失。
4.2 Speed and Accuracy of Kernels
速度。 我々は、headdim=64、headdim=128、およびheaddim=256の構成を用いて、因果マスクの有無両方の場合で、SageAttention2 の速度をベースラインと比較する実験を行った(Vaswani, 2017 ) 。具体的には、図8 、図9 、および図10 は、RTX4090上での様々なシーケンス長におけるSageAttention2 とベースラインの速度を示している。これらの結果は、SageAttention2 が最大485 TOPSを達成し、FlashAttention2より3.1倍 速く、xformersより5.4倍 速いことを示している。
図11 、図12 、および図13 は、L20 GPU上での結果を示しており、SageAttention2 が最大288 TOPSを達成し、FlashAttention2より2.7倍 速く、xformersより4.6倍 速いことを示している。
精度。 表10 および表10 は、CogvideoX のすべての層にわたるINT4 Q , K 𝑄 𝐾
Q,K italic_Q , italic_K およびFP8 P , V 𝑃 𝑉
P,V italic_P , italic_V を用いた異なる手法の平均精度と最悪精度を示している。結果は、SageAttn-4b の精度が他のベースラインよりも優れていることを示している。
図15: 異なる割合のSageAttn-8b を使用したエンドツーエンドのパフォーマンス。
4.3 End-to-end Performance
メトリクスの損失。 我々は、フル精度のアテンションと比較して、SageAttention2 を使用した様々なモデルのエンドツーエンドメトリクスを評価した。
詳細な評価結果は、Llama2 、Llama3.1 、GLM4 、CogvideoX 、Open-Sora 、Flux 、およびTIMM について、表11 に示されている。結果は、SageAttn-4b がすべてのベースラインを上回り、すべてのモデルにおいてエンドツーエンドの精度のほとんどを維持していることを示している。さらに、適応的量子化技術を用いることで、SageAttn-mix はフル精度のアテンションに匹敵する性能を達成している。
4.4 Ablation Study
適応的量子化。 適応的量子化におけるSageAttn-8b の異なる比率の使用の影響を分析するために、表15 はLlama3.1 のパープレキシティの変化を様々なSageAttn-8b の比率で示している。SageAttn-4b のみを使用した場合でも、全体的なエンドツーエンドの表現は十分に良好であり、使用するSageAttn-8b の比率が高いほど、精度はフル精度のアテンションを使用した場合に近づくことが観察できる。
Q, Kの平滑化のオーバーヘッド。 Q , K 𝑄 𝐾
Q,K italic_Q , italic_K の平滑化のオーバーヘッドには、Q − Q m → 𝑄 → subscript 𝑄 𝑚 Q-\overrightarrow{Q_{m}} italic_Q - over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG 、K − mean ( K ) 𝐾 mean 𝐾 K-\mathrm{mean}(K) italic_K - roman_mean ( italic_K ) 、およびQ m → K ⊤ → subscript 𝑄 𝑚 superscript 𝐾 top \overrightarrow{Q_{m}}K^{\top} over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT のみが含まれることに注意されたい。これらの速度オーバーヘッドは、アテンションカーネルの約3.5%を占める。
5 Conclusion and Future Work
我々は、注意機構の効率的かつ正確な4ビット量子化手法であるSageAttention2 を提案する。まず、( Q , K ) 𝑄 𝐾 (Q,K) ( italic_Q , italic_K ) をワープレベルの粒度でINT4に量子化し、( P ~ , V ) ~ 𝑃 𝑉 (\widetilde{P},V) ( over~ start_ARG italic_P end_ARG , italic_V ) をFP8に量子化することを提案する。次に、行列Q 𝑄 Q italic_Q とV 𝑉 V italic_V を平滑化する手法を提案し、INT4のQ K 𝑄 𝐾 QK italic_Q italic_K とFP8のP V 𝑃 𝑉 PV italic_P italic_V を用いた注意機構の精度を向上させる。さらに、タイムステップと層にわたる量子化精度を分析し、様々なモデルにおけるエンドツーエンドの指標を保証する適応的量子化手法を提案する。我々の実装は、RTX4090上でFlashAttention2とxformersと比較して、それぞれ約3.1倍 および5.4倍 高速である。広範なテストにより、本稿のアプローチが言語、画像、動画生成を含む様々なモデルにおいてエンドツーエンドの指標を維持することが確認された。
今後の課題。 Hopper アーキテクチャ上でのP ~ V ~ 𝑃 𝑉 \widetilde{P}V over~ start_ARG italic_P end_ARG italic_V のFP16アキュムレータを用いたFP8 MatMulの実装とSageAttention2 の実装は今後の課題として残されている。
References
Ashkboos et al. (2024)
Ashkboos, S., Mohtashami, A., Croci, M. L., Li, B., Cameron, P., Jaggi, M., Alistarh, D., Hoefler, T., and Hensman, J.
Quarot: Outlier-free 4-bit inference in rotated llms, 2024.
URL https://arxiv.org/abs/2404.00456 .
Bai et al. (2023)
Bai, Y., Lv, X., Zhang, J., Lyu, H., Tang, J., Huang, Z., Du, Z., Liu, X., Zeng, A., Hou, L., et al.
Longbench: A bilingual, multitask benchmark for long context understanding.
arXiv preprint arXiv:2308.14508 , 2023.
Black Forest Labs (2023)
Black Forest Labs.
Flux.
https://github.com/black-forest-labs/flux , 2023.
Chen et al. (2023)
Chen, Y., Qian, S., Tang, H., Lai, X., Liu, Z., Han, S., and Jia, J.
Longlora: Efficient fine-tuning of long-context large language models.
arXiv preprint arXiv:2309.12307 , 2023.
Choromanski et al. (2020)
Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L., et al.
Rethinking attention with performers.
arXiv preprint arXiv:2009.14794 , 2020.
Chu et al. (2021)
Chu, X., Tian, Z., Wang, Y., Zhang, B., Ren, H., Wei, X., Xia, H., and Shen, C.
Twins: Revisiting the design of spatial attention in vision transformers.
Advances in neural information processing systems , 34:9355–9366, 2021.
Dao (2023)
Dao, T.
Flashattention-2: Faster attention with better parallelism and work partitioning.
arXiv preprint arXiv:2307.08691 , 2023.
Dao et al. (2022)
Dao, T., Fu, D., Ermon, S., Rudra, A., and Ré, C.
Flashattention: Fast and memory-efficient exact attention with io-awareness.
Advances in Neural Information Processing Systems , 35:16344–16359, 2022.
Deng et al. (2009)
Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L.
Imagenet: A large-scale hierarchical image database.
In 2009 IEEE conference on computer vision and pattern recognition , pp. 248–255. Ieee, 2009.
Dubey et al. (2024)
Dubey, A., Jauhri, A., Pandey, A., Kadian, A., Al-Dahle, A., Letman, A., Mathur, A., Schelten, A., Yang, A., Fan, A., et al.
The llama 3 herd of models.
arXiv preprint arXiv:2407.21783 , 2024.
Fu et al. (2024)
Fu, T., Huang, H., Ning, X., Zhang, G., Chen, B., Wu, T., Wang, H., Huang, Z., Li, S., Yan, S., Dai, G., Yang, H., and Wang, Y.
Moa: Mixture of sparse attention for automatic large language model compression, 2024.
URL https://arxiv.org/abs/2406.14909 .
Gao et al. (2024)
Gao, Y., Zeng, Z., Du, D., Cao, S., So, H. K.-H., Cao, T., Yang, F., and Yang, M.
Seerattention: Learning intrinsic sparse attention in your llms, 2024.
URL https://arxiv.org/abs/2410.13276 .
gkamradt (2023)
gkamradt.
Llmtest needle in a haystack - pressure testing llms.
https://github.com/gkamradt/LLMTest_NeedleInAHaystack , 2023.
GLM et al. (2024)
GLM, T., Zeng, A., Xu, B., Wang, B., Zhang, C., Yin, D., Rojas, D., Feng, G., Zhao, H., Lai, H., Yu, H., Wang, H., Sun, J., Zhang, J., Cheng, J., Gui, J., Tang, J., Zhang, J., Li, J., Zhao, L., Wu, L., Zhong, L., Liu, M., Huang, M., Zhang, P., Zheng, Q., Lu, R., Duan, S., Zhang, S., Cao, S., Yang, S., Tam, W. L., Zhao, W., Liu, X., Xia, X., Zhang, X., Gu, X., Lv, X., Liu, X., Liu, X., Yang, X., Song, X., Zhang, X., An, Y., Xu, Y., Niu, Y., Yang, Y., Li, Y., Bai, Y., Dong, Y., Qi, Z., Wang, Z., Yang, Z., Du, Z., Hou, Z., and Wang, Z.
Chatglm: A family of large language models from glm-130b to glm-4 all tools, 2024.
Hendrycks et al. (2020)
Hendrycks, D., Burns, C., Basart, S., Zou, A., Mazeika, M., Song, D., and Steinhardt, J.
Measuring massive multitask language understanding.
2020.
Hendrycks et al. (2021)
Hendrycks, D., Basart, S., Mu, N., Kadavath, S., Wang, F., Dorundo, E., Desai, R., Zhu, T., Parajuli, S., Guo, M., Song, D., Steinhardt, J., and Gilmer, J.
The many faces of robustness: A critical analysis of out-of-distribution generalization.
ICCV , 2021.
Hessel et al. (2021)
Hessel, J., Holtzman, A., Forbes, M., Le Bras, R., and Choi, Y.
Clipscore: A reference-free evaluation metric for image captioning.
In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing , pp. 7514–7528, 2021.
Heusel et al. (2017)
Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S.
Gans trained by a two time-scale update rule converge to a local nash equilibrium.
Advances in neural information processing systems , 30, 2017.
Jelinek et al. (1977)
Jelinek, F., Mercer, R. L., Bahl, L. R., and Baker, J. K.
Perplexity—a measure of the difficulty of speech recognition tasks.
The Journal of the Acoustical Society of America , 62(S1):S63–S63, 1977.
Jiang et al. (2024)
Jiang, H., Li, Y., Zhang, C., Wu, Q., Luo, X., Ahn, S., Han, Z., Abdi, A. H., Li, D., Lin, C.-Y., et al.
Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention.
arXiv preprint arXiv:2407.02490 , 2024.
Katharopoulos et al. (2020)
Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F.
Transformers are rnns: Fast autoregressive transformers with linear attention.
In International conference on machine learning , pp. 5156–5165. PMLR, 2020.
Lefaudeux et al. (2022)
Lefaudeux, B., Massa, F., Liskovich, D., Xiong, W., Caggiano, V., Naren, S., Xu, M., Hu, J., Tintore, M., Zhang, S., Labatut, P., Haziza, D., Wehrstedt, L., Reizenstein, J., and Sizov, G.
xformers: A modular and hackable transformer modelling library.
https://github.com/facebookresearch/xformers , 2022.
Li et al. (2024)
Li, D., Kamko, A., Akhgari, E., Sabet, A., Xu, L., and Doshi, S.
Playground v2.5: Three insights towards enhancing aesthetic quality in text-to-image generation, 2024.
(24)
Li, K., Wang, Y., Peng, G., Song, G., Liu, Y., Li, H., and Qiao, Y.
Uniformer: Unified transformer for efficient spatial-temporal representation learning.
In International Conference on Learning Representations .
Lin et al. (2024)
Lin, Y., Tang, H., Yang, S., Zhang, Z., Xiao, G., Gan, C., and Han, S.
Qserve: W4a8kv4 quantization and system co-design for efficient llm serving, 2024.
URL https://arxiv.org/abs/2405.04532 .
Liu et al. (2024)
Liu, Y., Cun, X., Liu, X., Wang, X., Zhang, Y., Chen, H., Liu, Y., Zeng, T., Chan, R., and Shan, Y.
Evalcrafter: Benchmarking and evaluating large video generation models.
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition , pp. 22139–22149, 2024.
Liu et al. (2021)
Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., and Guo, B.
Swin transformer: Hierarchical vision transformer using shifted windows.
In Proceedings of the IEEE/CVF international conference on computer vision , pp. 10012–10022, 2021.
Merity et al. (2022)
Merity, S., Xiong, C., Bradbury, J., and Socher, R.
Pointer sentinel mixture models.
In International Conference on Learning Representations , 2022.
Milakov & Gimelshein (2018)
Milakov, M. and Gimelshein, N.
Online normalizer calculation for softmax.
arXiv preprint arXiv:1805.02867 , 2018.
Paperno et al. (2016)
Paperno, D., Kruszewski, G., Lazaridou, A., Pham, N.-Q., Bernardi, R., Pezzelle, S., Baroni, M., Boleda, G., and Fernández, R.
The lambada dataset: Word prediction requiring a broad discourse context.
In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers) , pp. 1525–1534, 2016.
Salimans et al. (2016)
Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X.
Improved techniques for training gans.
Advances in neural information processing systems , 29, 2016.
Shah et al. (2024)
Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., and Dao, T.
Flashattention-3: Fast and accurate attention with asynchrony and low-precision.
arXiv preprint arXiv:2407.08608 , 2024.
Touvron et al. (2023)
Touvron, H., Martin, L., Stone, K., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S., et al.
Llama 2: Open foundation and fine-tuned chat models.
arXiv preprint arXiv:2307.09288 , 2023.
Vaswani (2017)
Vaswani, A.
Attention is all you need.
Advances in Neural Information Processing Systems , 2017.
Venkataramanan et al. (2023)
Venkataramanan, S., Ghodrati, A., Asano, Y. M., Porikli, F., and Habibian, A.
Skip-attention: Improving vision transformers by paying less attention.
arXiv preprint arXiv:2301.02240 , 2023.
Wang et al. (2019)
Wang, H., Ge, S., Lipton, Z., and Xing, E. P.
Learning robust global representations by penalizing local predictive power.
Advances in Neural Information Processing Systems , 32, 2019.
Wang et al. (2020)
Wang, S., Li, B. Z., Khabsa, M., Fang, H., and Ma, H.
Linformer: Self-attention with linear complexity.
arXiv preprint arXiv:2006.04768 , 2020.
Wightman (2019)
Wightman, R.
Pytorch image models.
https://github.com/rwightman/pytorch-image-models , 2019.
Wu et al. (2023)
Wu, H., Zhang, E., Liao, L., Chen, C., Hou, J., Wang, A., Sun, W., Yan, Q., and Lin, W.
Exploring video quality assessment on user generated contents from aesthetic and technical perspectives.
In Proceedings of the IEEE/CVF International Conference on Computer Vision , pp. 20144–20154, 2023.
Xiao et al. (2024)
Xiao, C., Zhang, P., Han, X., Xiao, G., Lin, Y., Zhang, Z., Liu, Z., and Sun, M.
Infllm: Training-free long-context extrapolation for llms with an efficient context memory.
In First Workshop on Long-Context Foundation Models@ ICML 2024 , 2024.
Xiao et al. (2023a)
Xiao, G., Lin, J., Seznec, M., Wu, H., Demouth, J., and Han, S.
Smoothquant: Accurate and efficient post-training quantization for large language models.
In International Conference on Machine Learning , pp. 38087–38099. PMLR, 2023a.
Xiao et al. (2023b)
Xiao, G., Tian, Y., Chen, B., Han, S., and Lewis, M.
Efficient streaming language models with attention sinks.
arXiv preprint arXiv:2309.17453 , 2023b.
Xu et al. (2024)
Xu, J., Liu, X., Wu, Y., Tong, Y., Li, Q., Ding, M., Tang, J., and Dong, Y.
Imagereward: Learning and evaluating human preferences for text-to-image generation.
Advances in Neural Information Processing Systems , 36, 2024.
Yang et al. (2024)
Yang, Z., Teng, J., Zheng, W., Ding, M., Huang, S., Xu, J., Yang, Y., Hong, W., Zhang, X., Feng, G., et al.
Cogvideox: Text-to-video diffusion models with an expert transformer.
arXiv preprint arXiv:2408.06072 , 2024.
Yu et al. (2022)
Yu, W., Luo, M., Zhou, P., Si, C., Zhou, Y., Wang, X., Feng, J., and Yan, S.
Metaformer is actually what you need for vision.
In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition , pp. 10819–10829, 2022.
Zhang et al. (2024)
Zhang, J., wei, J., Zhang, P., Zhu, J., and Chen, J.
Sageattention: Accurate 8-bit attention for plug-and-play inference acceleration, 2024.
URL https://arxiv.org/abs/2410.02367 .
Zhao et al. (2024)
Zhao, T., Fang, T., Liu, E., Rui, W., Soedarmadji, W., Li, S., Lin, Z., Dai, G., Yan, S., Yang, H., Ning, X., and Wang, Y.
Vidit-q: Efficient and accurate quantization of diffusion transformers for image and video generation, 2024.
Zheng et al. (2024)
Zheng, Z., Peng, X., Yang, T., Shen, C., Li, S., Liu, H., Zhou, Y., Li, T., and You, Y.
Open-sora: Democratizing efficient video production for all, March 2024.
URL https://github.com/hpcaitech/Open-Sora .