JaLMS
最新の AI 研究を日本語で解読

SageAttention2 Technical Report:
Accurate 4 Bit Attention for Plug-and-play Inference Acceleration

Jintao Zhang Tsinghua University Haofeng Huang Tsinghua University Pengle Zhang Tsinghua University Jia Wei Tsinghua University Jun Zhu Tsinghua University Jianfei Chen Tsinghua University
Abstract

線形層に対する量子化は広く使用されてきたが、注意機構の加速への応用は限定的であった。SageAttentionは8ビット行列乗算、16ビットアキュムレータを用いた16ビット行列乗算、および精度向上手法を活用し、FlashAttention2と比較して精度を保ちつつ2倍の高速化を実現するカーネルを実装している。注意機構の計算効率をさらに向上させつつ精度を維持するため、我々はSageAttention2を提案する。これは大幅に高速な4ビット行列乗算(Matmul)と追加の精度向上技術を活用している。第一に、我々はワープレベルの粒度で行列(Q,K)𝑄𝐾(Q,K)( italic_Q , italic_K )をINT4に量子化し、行列(P~,V)~𝑃𝑉(\widetilde{P},V)( over~ start_ARG italic_P end_ARG , italic_V )をFP8に量子化することを提案する。第二に、Q𝑄Qitalic_QV𝑉Vitalic_Vを平滑化する手法を提案し、INT4 QK𝑄𝐾QKitalic_Q italic_KとFP8 PV𝑃𝑉PVitalic_P italic_Vを用いた注意機構の精度を向上させる。第三に、タイムステップと層にわたる量子化精度を分析し、様々なモデルにおけるエンドツーエンドの評価指標を保証する適応的量子化手法を提案する。SageAttention2の1秒あたりの演算数(OPS)は、RTX4090上でFlashAttention2とxformersをそれぞれ約3倍および5倍上回る。包括的な実験により、我々のアプローチが大規模言語処理、画像生成、動画生成を含む多様なモデルにおいて、エンドツーエンドの評価指標の損失が無視できるほど小さいことが確認された。コードはhttps://github.com/thu-ml/SageAttentionで入手可能である

1 Introduction

アテンションの二次計算量の複雑さにより、実世界のアプリケーションにおいて配列が長くなるにつれて、効率的な実装がますます重要になっている (Jiang et al., 2024)。アテンションの計算需要を軽減するために、いくつかの戦略が開発されてきた。例えば、(1) 複雑さを O(N)𝑂𝑁O(N)italic_O ( italic_N ) に削減する線形アテンション手法 (Wang et al., 2020; Choromanski et al., 2020; Yu et al., 2022; Katharopoulos et al., 2020)、(2) コンテキストの一部を選択的に処理するスパースアテンション手法 (Liu et al., 2021; Chu et al., 2021; Li et al., ; Xiao et al., 2023b, 2024; Chen et al., 2023; Jiang et al., 2024; Venkataramanan et al., 2023; Gao et al., 2024; Fu et al., 2024) などがある。しかし、これらの手法は限られた範囲のモデルやタスクにのみ適している。広く使用されているアテンション手法は、ハードウェアの能力を活用して計算速度を向上させることでアテンションを最適化している。例えば、FlashAttention (Dao et al., 2022)、FlashAttention2 (Dao, 2023)、FlashAttention3 (Shah et al., 2024)、xformers (Lefaudeux et al., 2022)、SageAttention (Zhang et al., 2024) などがある。これらの研究は、コンテキストの一部に対するアテンションの計算を省略せず、様々なモデルやタスクにおいて印象的な速度と精度の性能を達成している。

注意機構における2つの行列乗算(Matmul)操作:QK𝑄superscript𝐾topQK^{\top}italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTPV𝑃𝑉PVitalic_P italic_Vについて、SageAttentionはQK𝑄superscript𝐾topQK^{\top}italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTをINT8に量子化し、PV𝑃𝑉PVitalic_P italic_VにはFP16アキュムレータを用いたFP16 Matmulを使用することで高速化を実現している。さらに、注意機構の精度を維持するため、SageAttentionはK𝐾Kitalic_Kのチャンネル単位の外れ値を除去することでスムージングを提案している。SageAttentionはFlashAttention2とxformersに比べて2×\times×および2.7×\times×の高速化を達成し、言語、画像、動画生成モデルにおいてエンドツーエンドの評価指標の損失が無視できるレベルである初めての量子化された注意機構となっている。しかしながら、SageAttentionには2つの弱点がある。(W1)INT8 MatmulはINT4の半分の速度しか達成できない。(W2)FP16アキュムレータを用いたFP16 MatmulはRTX4090とRTX3090 GPUにのみ対応している。 QK𝑄superscript𝐾topQK^{\top}italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTにはより高速なINT4テンソルコアを活用し、PV𝑃𝑉PVitalic_P italic_Vを一般的に高速化できる手法を使用するため、我々はSageAttention2を提案する。これはQ,K𝑄𝐾Q,Kitalic_Q , italic_KをINT4に、P,V𝑃𝑉P,Vitalic_P , italic_VをFP8に量子化するものである。

課題Q,K𝑄𝐾Q,Kitalic_Q , italic_KをINT4に、P,V𝑃𝑉P,Vitalic_P , italic_VをFP8に量子化することは重大な課題を提示する。例えば、Q,K𝑄𝐾Q,Kitalic_Q , italic_KをINT4にテンソル単位で量子化するだけでも、テキストから動画を生成するモデルCogvideoXは完全にぼやけた動画を生成し(図2参照)、Llama2はMMULUにおいてランダム推測レベルの25%の精度しか達成できない。詳細に調査した結果、我々は3つの主要な課題を特定した:(C1) INT4の数値範囲は、量子化において通常77-7- 7から7777までの15の数字を含むが(Lin et al., 2024)、これはQ𝑄Qitalic_QK𝐾Kitalic_Kに異常値がある場合、重大な量子化誤差につながる。(C2) 一部のモデルの特定の層とタイムステップ(テキストから画像/動画の場合)において、Q𝑄Qitalic_QK𝐾Kitalic_KをINT4に、P𝑃Pitalic_PV𝑉Vitalic_VをFP8に量子化すると、注意計算に顕著な誤差が生じる。これらの最悪のケースの層/タイムステップにおける誤差は、エンドツーエンドの出力の精度に大きな影響を与える。 (C3) 我々は、テンソルコアにおけるFP8行列乗算用に設計されたFP32アキュムレータ(mma.f32.f8.f8.f32)が実際にはFP22、具体的には1符号ビット、8指数ビット、13仮数ビットであることを発見した。これによりPV𝑃𝑉PVitalic_P italic_Vの精度損失が生じる。

Refer to caption
図2:CogvideoXからQ、KをINT4に量子化した例。

我々のアプローチ。これらの課題に対処するため、我々は理由を詳細に分析し、2つの方法を提案する。第一に、行列Q𝑄Qitalic_QK𝐾Kitalic_Kにおけるチャンネル方向の顕著な外れ値に対して、SageAttentionにおいて平滑化K𝐾Kitalic_Kを採用し、Q𝑄Qitalic_Qにおけるこれらの外れ値を除去する効果的な方法を提案する。具体的には、Q𝑄Qitalic_Qのチャンネル次元の平均値を減算することを提案し、これをQmsubscript𝑄𝑚\overrightarrow{Q}_{m}over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPTと呼ぶ。その後、QmKsubscript𝑄𝑚𝐾\overrightarrow{Q}_{m}Kover→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT italic_KQK𝑄𝐾QKitalic_Q italic_K Matmulの後に加えることで、注意計算の正確性を確保する。第二に、特定の層とタイムステップが異なる入力間で一貫して量子化の課題を示すことを観察した。精度を維持するために、適応的な混合精度法を適用する。具体的には、これらの問題のある層とタイムステップに対しては8ビット(INT8+FP8)の注意を、その他に対しては4ビット(INT4+FP8)の注意を使用する。第三に、PV𝑃𝑉PVitalic_P italic_VのFP8 Matmulに22ビットのアキュムレータを使用することに関連する精度損失を軽減するために、V𝑉Vitalic_Vを平滑化して精度性能を向上させる方法を提案する。

性能。重要なことに、我々はRTX4090およびL20 GPU上でSageAttention2の高性能実装を提供する。この実装はRTX4090上でピーク性能485 TOPSを達成し、FlashAttention2とxformersをそれぞれ約3.1倍および5.4倍上回る。我々はSageAttention2を使用して、最先端のテキスト、画像、および動画生成モデルのエンドツーエンドメトリクスを広範に評価した。すべてのモデルとタスクにおいて、SageAttention2はモデル性能の無視できる程度の損失で直接プラグアンドプレイ方式で採用できる。

2 Preliminary

2.1 FlashAttention

自己注意の計算は以下のように定式化できる:S=QK/d,P=σ(S),O=PVformulae-sequence𝑆𝑄superscript𝐾top𝑑formulae-sequence𝑃𝜎𝑆𝑂𝑃𝑉S=QK^{\top}/\sqrt{d},~{}P=\sigma(S),~{}O=PVitalic_S = italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT / square-root start_ARG italic_d end_ARG , italic_P = italic_σ ( italic_S ) , italic_O = italic_P italic_V、ここでσ(S)ij=exp(Sij)/kexp(Sik)𝜎subscript𝑆𝑖𝑗subscript𝑆𝑖𝑗subscript𝑘subscript𝑆𝑖𝑘\sigma(S)_{ij}=\exp(S_{ij})/\sum_{k}\exp(S_{ik})italic_σ ( italic_S ) start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT = roman_exp ( italic_S start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT ) / ∑ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT roman_exp ( italic_S start_POSTSUBSCRIPT italic_i italic_k end_POSTSUBSCRIPT )はソフトマックス演算である。 行列Q𝑄Qitalic_QK𝐾Kitalic_K、およびV𝑉Vitalic_VはそれぞれN×d𝑁𝑑N\times ditalic_N × italic_dの次元を持ち、一方で行列S𝑆Sitalic_SP𝑃Pitalic_PN×N𝑁𝑁N\times Nitalic_N × italic_Nである。d𝑑ditalic_dは通常小さく、例えば64や128であるが、N𝑁Nitalic_Nは数千あるいは数百万にもなり得る。したがって、N×N𝑁𝑁N\times Nitalic_N × italic_N行列(S,P)𝑆𝑃(S,P)( italic_S , italic_P )(Q,K,V)𝑄𝐾𝑉(Q,K,V)( italic_Q , italic_K , italic_V )よりもはるかに大きく、素朴な実装では(S,P)𝑆𝑃(S,P)( italic_S , italic_P )の読み書きに膨大なグローバルメモリIOが必要となる。FlashAttention (Dao, 2023)は、Q𝑄Qitalic_QK𝐾Kitalic_K、およびV𝑉Vitalic_Vをトークン次元からブロックサイズbqsubscript𝑏𝑞b_{q}italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPTbkvsubscript𝑏𝑘𝑣b_{kv}italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPTbkvsubscript𝑏𝑘𝑣b_{kv}italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPTのブロック{Qi},{Ki},{Vi}subscript𝑄𝑖subscript𝐾𝑖subscript𝑉𝑖\{Q_{i}\},\{K_{i}\},\{V_{i}\}{ italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } , { italic_K start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } , { italic_V start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT }にタイル化することを提案している。そして、(S,P)𝑆𝑃(S,P)( italic_S , italic_P )のメモリIOを避けるために、オンラインソフトマックス(Milakov & Gimelshein, 2018)を使用してO𝑂Oitalic_Oの各ブロック、つまりOisubscript𝑂𝑖O_{i}italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTを段階的に計算する:

まず、{Ki},{Vi}subscript𝐾𝑖subscript𝑉𝑖\{K_{i}\},\{V_{i}\}{ italic_K start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } , { italic_V start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT }の各ブロックに対して、以下の方程式を反復的に計算する:

Sij=QiKj/d,(mij,P~ij)=σ~(mij1,Sij),lij=\displaystyle S^{j}_{i}=Q_{i}K_{j}^{\top}/\sqrt{d},~{}~{}(m^{j}_{i},\widetilde% {P}^{j}_{i})=\tilde{\sigma}(m^{j-1}_{i},S^{j}_{i}),~{}~{}l_{i}^{j}=italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT / square-root start_ARG italic_d end_ARG , ( italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , over~ start_ARG italic_P end_ARG start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) = over~ start_ARG italic_σ end_ARG ( italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) , italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = exp(mijmij1)lij1+rowsum(P~ij),subscriptsuperscript𝑚𝑗𝑖subscriptsuperscript𝑚𝑗1𝑖superscriptsubscript𝑙𝑖𝑗1rowsumsubscriptsuperscript~𝑃𝑗𝑖\displaystyle\exp(m^{j}_{i}-m^{j-1}_{i})l_{i}^{j-1}+\mathrm{rowsum}(\widetilde% {P}^{j}_{i}),roman_exp ( italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT + roman_rowsum ( over~ start_ARG italic_P end_ARG start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ,
Oij=diag(exp(mijmij1))Oij1subscriptsuperscript𝑂𝑗𝑖diagsubscriptsuperscript𝑚𝑗𝑖subscriptsuperscript𝑚𝑗1𝑖subscriptsuperscript𝑂𝑗1𝑖\displaystyle O^{j}_{i}=\mathrm{diag}\left(\exp(m^{j}_{i}-m^{j-1}_{i})\right)O% ^{j-1}_{i}italic_O start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_diag ( roman_exp ( italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ) italic_O start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT +P~ijVjsubscriptsuperscript~𝑃𝑗𝑖subscript𝑉𝑗\displaystyle+\widetilde{P}^{j}_{i}V_{j}+ over~ start_ARG italic_P end_ARG start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_V start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT

ここで、mijsuperscriptsubscript𝑚𝑖𝑗m_{i}^{j}italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPTlijsuperscriptsubscript𝑙𝑖𝑗l_{i}^{j}italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPTbq×1subscript𝑏𝑞1b_{q}\times 1italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT × 1ベクトルであり、それぞれ-\infty- ∞00で初期化される。σ~()~𝜎\tilde{\sigma}()over~ start_ARG italic_σ end_ARG ( )はオンラインソフトマックス演算子である:mij=max{mij1,rowmax(Sij)},P~ji=exp(Sijmij)formulae-sequencesubscriptsuperscript𝑚𝑗𝑖subscriptsuperscript𝑚𝑗1𝑖rowmaxsubscriptsuperscript𝑆𝑗𝑖superscriptsubscript~𝑃𝑗𝑖subscriptsuperscript𝑆𝑗𝑖subscriptsuperscript𝑚𝑗𝑖m^{j}_{i}=\max\{m^{j-1}_{i},\mathrm{rowmax}(S^{j}_{i})\},~{}\widetilde{P}_{j}^% {i}=\exp(S^{j}_{i}-m^{j}_{i})italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_max { italic_m start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , roman_rowmax ( italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) } , over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_i end_POSTSUPERSCRIPT = roman_exp ( italic_S start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_m start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT )

最後に、出力Oisubscript𝑂𝑖O_{i}italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTOi=diag(lij)1Oijsubscript𝑂𝑖diagsuperscriptsuperscriptsubscript𝑙𝑖𝑗1superscriptsubscript𝑂𝑖𝑗O_{i}=\mathrm{diag}(l_{i}^{j})^{-1}O_{i}^{j}italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_diag ( italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPTによって計算できる。

2.2 Quantization

行列乗算 C=AB𝐶𝐴𝐵C=ABitalic_C = italic_A italic_B は、以下のように量子化によって加速できる:

(δA,A^)=ψ(A),(δB,B^)=ψ(B),C^=A^B^,formulae-sequencesubscript𝛿𝐴^𝐴𝜓𝐴formulae-sequencesubscript𝛿𝐵^𝐵𝜓𝐵^𝐶^𝐴^𝐵\displaystyle(\delta_{A},\hat{A})=\psi(A),~{}~{}(\delta_{B},\hat{B})=\psi(B),~% {}~{}\hat{C}=\hat{A}\hat{B},( italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT , over^ start_ARG italic_A end_ARG ) = italic_ψ ( italic_A ) , ( italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT , over^ start_ARG italic_B end_ARG ) = italic_ψ ( italic_B ) , over^ start_ARG italic_C end_ARG = over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG , C=ψδAδB1(C^)𝐶subscriptsuperscript𝜓1subscript𝛿𝐴subscript𝛿𝐵^𝐶\displaystyle~{}~{}C=\psi^{-1}_{\delta_{A}\delta_{B}}(\hat{C})italic_C = italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_C end_ARG ) (1)

ψ𝜓\psiitalic_ψ量子化器であり、高精度(例:FP32またはFP16)の行列 A𝐴Aitalic_A を低精度形式 A^^𝐴\hat{A}over^ start_ARG italic_A end_ARG(例:INT4またはFP8)にスケール δAsubscript𝛿𝐴\delta_{A}italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT で変換し、ψ1superscript𝜓1\psi^{-1}italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT は高精度に戻す逆量子化器である。我々は ψδA1(A^)Asubscriptsuperscript𝜓1subscript𝛿𝐴^𝐴𝐴\psi^{-1}_{\delta_{A}}(\hat{A})\approx Aitalic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_A end_ARG ) ≈ italic_A を持つべきである。実際の行列乗算 A^B^^𝐴^𝐵\hat{A}\hat{B}over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG は低精度で行われる。現代のGPUでは、低精度の行列乗算は通常、高精度のものよりも数倍高速である。 多くの量子化器は、数値形式と粒度に依存する。例えば、共通のスケール因子を共有する要素の数などである。 例えば、INT4のテンソル単位の量子化器は、まずテンソル全体の絶対値の最大値としてスケールを計算し、要素をINT4の最大表現可能範囲[-7, +7]にスケーリングし、その後四捨五入してINT4にキャストする:A^=A/δA,δA=max(|A|)/7\hat{A}=\lceil A/\delta_{A}\rfloor,\delta_{A}=\max(|{A}|)/7over^ start_ARG italic_A end_ARG = ⌈ italic_A / italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT ⌋ , italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT = roman_max ( | italic_A | ) / 7。 同様に、トークン単位の量子化器はテンソルの各トークンにスケール因子を割り当てる:A^[i,:]=A[i,:]/δA[i],δA[i]=max(|A[i,:]|)/7\hat{A}[i,:]=\lceil A[i,:]/\delta_{A}[i]\rfloor,\delta_{A}[i]=\max(|{A[i,:]}|)/7over^ start_ARG italic_A end_ARG [ italic_i , : ] = ⌈ italic_A [ italic_i , : ] / italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] ⌋ , italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] = roman_max ( | italic_A [ italic_i , : ] | ) / 7。 また、チャンネル単位の量子化器はテンソルの各チャンネルにスケール因子を割り当てる。つまり、チャンネル次元に沿って:A[:,i]=A[:,i]/δA[i],δA[i]=max(|A[:,i]|)/7A[:,i]=\lceil A[:,i]/\delta_{A}[i]\rfloor,\delta_{A}[i]=\max(|A[:,i]|)/{7}italic_A [ : , italic_i ] = ⌈ italic_A [ : , italic_i ] / italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] ⌋ , italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT [ italic_i ] = roman_max ( | italic_A [ : , italic_i ] | ) / 7。 逆量子化プロセスは要素ごとのスケーリングである:ψδAδB1(A^B^)=A^B^δAδBsuperscriptsubscript𝜓subscript𝛿𝐴subscript𝛿𝐵1^𝐴^𝐵^𝐴^𝐵subscript𝛿𝐴subscript𝛿𝐵\psi_{\delta_{A}\delta_{B}}^{-1}(\hat{A}\hat{B})=\hat{A}\hat{B}*\delta_{A}*% \delta_{B}italic_ψ start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT ( over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG ) = over^ start_ARG italic_A end_ARG over^ start_ARG italic_B end_ARG ∗ italic_δ start_POSTSUBSCRIPT italic_A end_POSTSUBSCRIPT ∗ italic_δ start_POSTSUBSCRIPT italic_B end_POSTSUBSCRIPT

2.3 SageAttention

FlashAttention2 (Dao, 2023) のブロックタイリングに基づき、SageAttention (Zhang et al., 2024) はブロック単位の粒度で Q,K𝑄𝐾Q,Kitalic_Q , italic_K をINT8に量子化する。さらに、量子化の精度を保つために、SageAttentionは最初に K𝐾Kitalic_K を平滑化することを提案している:

K=K𝐾𝐾\displaystyle K=Kitalic_K = italic_K mean(K)mean𝐾\displaystyle-\text{mean}(K)- mean ( italic_K )
Q^i=Qi/δQ,δQ=max(|Qi|)/127,\displaystyle\hat{Q}_{i}=\lceil Q_{i}/\delta_{Q}\rfloor,~{}~{}\delta_{Q}=\max(% |{Q_{i}|})/127,over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = ⌈ italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT / italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ⌋ , italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT = roman_max ( | italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | ) / 127 , K^j=Kj/δK,δK=max(|Kj|)/127\displaystyle~{}~{}\hat{K}_{j}=\lceil K_{j}/\delta_{K}\rfloor,~{}~{}\delta_{K}% =\max(|{K_{j}|})/127over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT = ⌈ italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT / italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ⌋ , italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT = roman_max ( | italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT | ) / 127
Sij=subscript𝑆𝑖𝑗absent\displaystyle S_{ij}=italic_S start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT = QiKjδQδKsubscript𝑄𝑖superscriptsubscript𝐾𝑗topsubscript𝛿𝑄subscript𝛿𝐾\displaystyle Q_{i}K_{j}^{\top}*\delta_{Q}*\delta_{K}italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ∗ italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ∗ italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT

ここで、Qi,Kjsubscript𝑄𝑖subscript𝐾𝑗Q_{i},K_{j}italic_Q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_K start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT はFlashAttentionでタイル化されたブロックであり、mean(K)mean𝐾\text{mean}(K)mean ( italic_K )K𝐾Kitalic_K のチャンネル次元の平均値である。SageAttentionは P~V~𝑃𝑉\widetilde{P}Vover~ start_ARG italic_P end_ARG italic_V をFP16として保持し、P~V~𝑃𝑉\widetilde{P}Vover~ start_ARG italic_P end_ARG italic_V にはFP16アキュムレータを使用したFP16行列乗算を使用する。しかし、FP16アキュムレータを使用したFP16行列乗算は、RTX4090およびRTX3090 GPUでのみ高速化効果がある。

3 SageAttention-2

Refer to caption
図3: SageAttention2のワークフロー。 1 Q,K,Vを平滑化する。 2 GEMVを実行してΔSΔ𝑆\Delta Sroman_Δ italic_Sを得る。 3 ワープごとにQ,Kを量子化し、チャンネルごとにVを量子化する。 4 SageAttention2カーネルを実行する。 5 出力を修正する。

3.1 Formulation

2節で紹介したFlashAttentionと量子化に基づき、我々が開発した量子化注意機構アプローチについて説明する。

Quantization: (δQ,Q^,Qm)=ϕQ(Q),(δK,K^)=ϕK(K),(δV,V^,Vm)=ϕV(V)formulae-sequencesubscript𝛿𝑄^𝑄subscript𝑄𝑚subscriptitalic-ϕ𝑄𝑄formulae-sequencesubscript𝛿𝐾^𝐾subscriptitalic-ϕ𝐾𝐾subscript𝛿𝑉^𝑉subscript𝑉𝑚subscriptitalic-ϕ𝑉𝑉\displaystyle{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1% }{(\delta_{Q},\hat{Q},~{}~{}\overrightarrow{Q}_{m})}}={\color[rgb]{0,0,1}% \definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\phi_{Q}}}(Q),~{}~{}{\color[% rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(\delta_{K},\hat{K}% )}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\phi_{K% }}}(K),~{}~{}{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1% }{(\delta_{V},\hat{V},\overrightarrow{V}_{m})}}={\color[rgb]{0,0,1}% \definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\phi_{V}}}(V)( italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT , over^ start_ARG italic_Q end_ARG , over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) = italic_ϕ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) , ( italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT , over^ start_ARG italic_K end_ARG ) = italic_ϕ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K ) , ( italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT , over^ start_ARG italic_V end_ARG , over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT ) = italic_ϕ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V )
ΔS=QmKΔ𝑆subscript𝑄𝑚superscript𝐾top\displaystyle{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1% }{\Delta S=\overrightarrow{Q_{m}}K^{\top}}}roman_Δ italic_S = over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT (2)
Attention: S=ψδQδK1(Q^K^)+ΔS,(m,P~)=σ~(m,S),(δP,P^)=ψP(P~)formulae-sequence𝑆subscriptsuperscript𝜓1subscript𝛿𝑄subscript𝛿𝐾^𝑄superscript^𝐾topΔ𝑆formulae-sequencesuperscript𝑚~𝑃~𝜎𝑚𝑆subscript𝛿𝑃^𝑃subscript𝜓𝑃~𝑃\displaystyle S={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{% 0,0,1}{\psi^{-1}_{\delta_{Q}\delta_{K}}(\hat{Q}\hat{K}^{\top})+\Delta S}},~{}~% {}(m^{\prime},\widetilde{P})=\tilde{\sigma}(m,S),~{}~{}{\color[rgb]{0,0,1}% \definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(\delta_{P},\hat{P})}}={\color% [rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\psi_{P}}}(% \widetilde{P})italic_S = italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_Q end_ARG over^ start_ARG italic_K end_ARG start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) + roman_Δ italic_S , ( italic_m start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT , over~ start_ARG italic_P end_ARG ) = over~ start_ARG italic_σ end_ARG ( italic_m , italic_S ) , ( italic_δ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT , over^ start_ARG italic_P end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG )
O=diag(exp(mm))O+ψδPδV1(P^V^)𝑂diagsuperscript𝑚𝑚𝑂subscriptsuperscript𝜓1subscript𝛿𝑃subscript𝛿𝑉^𝑃^𝑉\displaystyle O=\mathrm{diag}\left(\exp(m^{\prime}-m)\right)O+{\color[rgb]{% 0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{\psi^{-1}_{\delta_{P}% \delta_{V}}(\hat{P}\hat{V})}}italic_O = roman_diag ( roman_exp ( italic_m start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT - italic_m ) ) italic_O + italic_ψ start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( over^ start_ARG italic_P end_ARG over^ start_ARG italic_V end_ARG ) (3)

ϕQsubscriptitalic-ϕ𝑄\phi_{Q}italic_ϕ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPTϕKsubscriptitalic-ϕ𝐾\phi_{K}italic_ϕ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPTϕVsubscriptitalic-ϕ𝑉\phi_{V}italic_ϕ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPTは、量子化されたQ𝑄Qitalic_QK𝐾Kitalic_KV𝑉Vitalic_Vを得るための3つの変換であり、これらについては後続のセクションで議論する。 簡略化のため、すべての上付き文字と下付き文字は省略するが、注意機構で使用される行列は依然としてタイルであり、計算は第2.1節で説明したFlashAttentionとして組織化されている。式23に示されている元の全精度版と比較して、SageAttention2Q,K,P,V𝑄𝐾𝑃𝑉Q,K,P,Vitalic_Q , italic_K , italic_P , italic_Vに量子化器を追加し、積に逆量子化器を追加することで、QK𝑄𝐾QKitalic_Q italic_KP~V~𝑃𝑉\widetilde{P}Vover~ start_ARG italic_P end_ARG italic_Vの両方の行列乗算を加速している。

Refer to caption
図4: 注意機構におけるテンソルのデータ分布の典型的な例。
表1: 異なる量子化手法のエンドツーエンドメトリクスの比較。Q,Kは4ビット整数に量子化され、P,Vは全精度のままである。
Q, K Smoothing (Q+K) Llama3.1 (Lambda) \uparrow Llama3.1 (WikeText) \downarrow CogVideo (vqa-a) \uparrow CogVideo (vqa-t) \uparrow
Full-Precision - 81.5% 6.013 77.605 75.360
INT4 Quantization 72.6% 11.698 27.114 24.670
80.8% 6.219 77.276 75.147
表2: 異なる量子化粒度を使用した場合の全層にわたる平均精度
Method Cos Sim \uparrow Relative L1 \downarrow RMSE \downarrow
Per-token 99.45% 0.0649 0.0335
Per-warp 99.45% 0.0648 0.0334
Per-block 98.03% 0.1492 0.0744
Per-tensor 97.15% 0.1800 0.0865
表3: 異なる量子化粒度を使用した場合の全層にわたる最悪精度
Method Cos Sim \uparrow Relative L1 \downarrow RMSE \downarrow
Per-token 96.76% 0.1916 0.0775
Per-warp 96.71% 0.1956 0.0779
Per-block 90.68% 0.3615 0.1490
Per-tensor 85.85% 0.4687 0.2261

3.2 Per-warp INT4 Quantization

SageAttentionはブロック単位の量子化を使用しており、GPUのストリーミングプロセッサごとにQ𝑄Qitalic_QK𝐾Kitalic_Kの各ブロックを量子化する。このような量子化戦略は、トークン単位の量子化に近い精度性能を達成し、量子化スケールベクトルδQsubscript𝛿𝑄\delta_{Q}italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPTδKsubscript𝛿𝐾\delta_{K}italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPTのドット積のオーバーヘッドを回避できる。しかし、Q𝑄Qitalic_QK𝐾Kitalic_KをINT4に量子化するには、より正確な量子化粒度が必要である。我々はワープ単位の量子化を提案する。これはブロック単位の量子化器よりも精密で粒度の細かい量子化アプローチであり、ベクトルのドット積による追加のオーバーヘッドもない。

具体的には、SageAttentionにおけるQ𝑄Qitalic_Qの各ブロックはcwsubscript𝑐𝑤c_{w}italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT個のセグメントに分割され、それぞれがGPUストリーミングプロセッサ(SM)内のcwsubscript𝑐𝑤c_{w}italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT個のGPUワープの1つによって処理される。その後、各ワープが割り当てられたセグメントに対してMatmulを実行する。ワープ単位のINT4量子化は、Q𝑄Qitalic_Qbq/wcsubscript𝑏𝑞𝑤𝑐b_{q}/{wc}italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT / italic_w italic_cトークンごとにスケール因子を割り当てる:

Q^[ibqcw\displaystyle\hat{Q}[\frac{i*b_{q}}{c_{w}}over^ start_ARG italic_Q end_ARG [ divide start_ARG italic_i ∗ italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG :bq(i+1)cw,:]=Q[ibqcw:bq(i+1)cw,:]δQ[i]\displaystyle:\frac{b_{q}*(i+1)}{c_{w}},:]=\left\lceil\frac{Q[\frac{i*b_{q}}{c% _{w}}:\frac{b_{q}*(i+1)}{c_{w}},:]}{\delta_{Q}[i]}\right\rfloor: divide start_ARG italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ∗ ( italic_i + 1 ) end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG , : ] = ⌈ divide start_ARG italic_Q [ divide start_ARG italic_i ∗ italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG : divide start_ARG italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ∗ ( italic_i + 1 ) end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG , : ] end_ARG start_ARG italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_i ] end_ARG ⌋ (4)
δQ[i]=max(|Q[ibqcw:bq(i+1)cw,:]|)7\displaystyle\delta_{Q}[i]=\frac{\max(\left|\,Q[\frac{i*b_{q}}{c_{w}}:\frac{b_% {q}*(i+1)}{c_{w}},:]\,\right|)}{7}italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_i ] = divide start_ARG roman_max ( | italic_Q [ divide start_ARG italic_i ∗ italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG : divide start_ARG italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT ∗ ( italic_i + 1 ) end_ARG start_ARG italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT end_ARG , : ] | ) end_ARG start_ARG 7 end_ARG (5)

ワープ単位の量子化は、SageAttentionで使用されているブロック単位の量子化よりも優れた精度性能を提供する。これについては3.5節で議論する。

Refer to caption
図5: 平滑化Q𝑄Qitalic_Qの前後におけるQ𝑄Qitalic_Qの量子化値分布の例。

3.3 Smooth Q

INT4量子化の代表的な範囲は著しく制限されており、すなわち241=15superscript241152^{4}-1=152 start_POSTSUPERSCRIPT 4 end_POSTSUPERSCRIPT - 1 = 15値のみです。この制限は、注意機構においてQ,K𝑄𝐾Q,Kitalic_Q , italic_KをINT4に量子化する際に性能を大幅に低下させます。例えば、Q,K𝑄𝐾Q,Kitalic_Q , italic_KをINT4で量子化すると、WikiTextにおけるLlama3.1のパープレキシティが90%以上増加し、CogvideoXが生成する動画の品質が約3倍低下します(表1参照)。我々は実際のモデルにおけるQ,K𝑄𝐾Q,Kitalic_Q , italic_Kのデータ分布を分析しました。例えば、Llama3.1CogvideoXではQ,K𝑄𝐾Q,Kitalic_Q , italic_Kにおいてチャンネル単位の顕著な外れ値が見られることがわかりました(図4参照)。チャンネルごとの量子化はこのような外れ値による量子化誤差を軽減できますが、この方法はQ,K𝑄𝐾Q,Kitalic_Q , italic_Kには適用できません。なぜなら、量子化はQK𝑄superscript𝐾topQK^{\top}italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTの外れ値がもたらす量子化誤差を排除する方法を提案します:

Qm=mean(Q),γ(Q)=\displaystyle\overrightarrow{Q_{m}}=\mathrm{mean}(Q),~{}~{}\gamma(Q)=over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG = roman_mean ( italic_Q ) , italic_γ ( italic_Q ) = QQm,γ(K)=Kmean(K),𝑄subscript𝑄𝑚𝛾𝐾𝐾mean𝐾\displaystyle Q-\overrightarrow{Q_{m}},~{}~{}\gamma(K)=K-\mathrm{mean}(K),italic_Q - over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG , italic_γ ( italic_K ) = italic_K - roman_mean ( italic_K ) ,
S=QK=𝑆𝑄𝐾absent\displaystyle S=QK=italic_S = italic_Q italic_K = γ(Q)γ(K)+Qmγ(K)𝛾𝑄𝛾𝐾subscript𝑄𝑚𝛾𝐾\displaystyle\gamma(Q)\gamma(K)+\overrightarrow{Q_{m}}\gamma(K)italic_γ ( italic_Q ) italic_γ ( italic_K ) + over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_γ ( italic_K )
ϕQ(Q)=ψQγ(Q),Qmsubscriptitalic-ϕ𝑄𝑄subscript𝜓𝑄𝛾𝑄subscript𝑄𝑚\displaystyle\phi_{Q}(Q)=\psi_{Q}\circ\gamma(Q),~{}\overrightarrow{Q_{m}}italic_ϕ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) = italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ∘ italic_γ ( italic_Q ) , over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG ,ϕK(K)=ψKγ(K)\displaystyle,~{}~{}~{}~{}~{}\phi_{K}(K)=\psi_{K}\circ\gamma(K), italic_ϕ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K ) = italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ∘ italic_γ ( italic_K ) (6)

ここで、mean(K)mean𝐾\mathrm{mean}(K)roman_mean ( italic_K )γ(K)𝛾𝐾\gamma(K)italic_γ ( italic_K )はSageAttention (Zhang et al., 2024)で議論されており、K𝐾Kitalic_Kの変換は注意スコアP~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARGを変更しません。mean(Q)=1Nt=1NQ[t,:]mean𝑄1𝑁superscriptsubscript𝑡1𝑁𝑄𝑡:\mathrm{mean}(Q)=\frac{1}{N}\sum_{t=1}^{N}Q[t,:]roman_mean ( italic_Q ) = divide start_ARG 1 end_ARG start_ARG italic_N end_ARG ∑ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT italic_Q [ italic_t , : ]は形状1×d1𝑑1\times d1 × italic_dのベクトルです。Q𝑄Qitalic_Qの変換については、Qmsubscript𝑄𝑚\overrightarrow{Q_{m}}over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARGKsuperscript𝐾topK^{\top}italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTの間でGEMV(一般行列ベクトル乗算)を実行します。つまり、QmKsubscript𝑄𝑚superscript𝐾top\overrightarrow{Q_{m}}K^{\top}over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTΔSΔ𝑆\Delta Sroman_Δ italic_Sとして計算します。このΔSΔ𝑆\Delta Sroman_Δ italic_Sは注意の計算中にS𝑆Sitalic_Sに加算されます。 最終的に、全精度のQ,K𝑄𝐾Q,Kitalic_Q , italic_Kから量子化されたQ^,K^^𝑄^𝐾\hat{Q},\hat{K}over^ start_ARG italic_Q end_ARG , over^ start_ARG italic_K end_ARGへの変換は式6のように表現できます。ここで、ψQsubscript𝜓𝑄\psi_{Q}italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPTψKsubscript𝜓𝐾\psi_{K}italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPTQ𝑄Qitalic_QK𝐾Kitalic_Kのための2つの量子化器です。言い換えれば、全精度のQ,K𝑄𝐾Q,Kitalic_Q , italic_Kはチャンネル次元の平均値を減算してから量子化されます。

5は、CogvideoXから平滑化Qの有無による量子化されたQ𝑄Qitalic_Qの分布の例を示しています。平滑化Qを使用することで、INT4の範囲がより均一かつ完全に利用されていることがわかります。 表1は、Llama3.1 (Dubey et al., 2024)CogvideoX (Yang et al., 2024)における平滑化Q+Kの有無による異なる量子化方法のエンドツーエンドの指標を示しています。結果は平滑化Q+Kが精度に大きな利点をもたらすことを示しています。また、表10と表10は、効果の順序が平滑化Q+K >>> 平滑化Q >>> 平滑化K >>> その他のベースラインであることを示しています。

アルゴリズム1 SageAttention2の実装。
入力: 行列 Q(FP16),K(FP16),V(FP16)N×d𝑄FP16𝐾FP16𝑉FP16superscript𝑁𝑑Q(\text{FP16}),K(\text{FP16}),V(\text{FP16})\in\mathbb{R}^{N\times d}italic_Q ( FP16 ) , italic_K ( FP16 ) , italic_V ( FP16 ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_N × italic_d end_POSTSUPERSCRIPT、ブロックサイズ bq,bkvsubscript𝑏𝑞subscript𝑏𝑘𝑣b_{q},b_{kv}italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT , italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPT、ワープ数 cwsubscript𝑐𝑤c_{w}italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT
前処理: K=Kmean(K)𝐾𝐾mean𝐾K=K-\mathrm{mean}(K)italic_K = italic_K - roman_mean ( italic_K ), Qm=mean(Q)subscript𝑄𝑚mean𝑄\overrightarrow{Q}_{m}=\mathrm{mean}(Q)over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = roman_mean ( italic_Q ), Q=QQm𝑄𝑄subscript𝑄𝑚Q=Q-\overrightarrow{Q}_{m}italic_Q = italic_Q - over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT, ΔS=GEMV(Qm,K)Δ𝑆GEMVsubscript𝑄𝑚superscript𝐾top\Delta S=\text{GEMV}(\overrightarrow{Q}_{m},K^{\top})roman_Δ italic_S = GEMV ( over→ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT , italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ), Vm=mean(V),V=VVmformulae-sequencesubscript𝑉𝑚mean𝑉𝑉𝑉subscript𝑉𝑚\overrightarrow{V}_{m}=\mathrm{mean}(V),V=V-\overrightarrow{V}_{m}over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = roman_mean ( italic_V ) , italic_V = italic_V - over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT
量子化: (δQ,Q^)=ψQ(Q)subscript𝛿𝑄^𝑄subscript𝜓𝑄𝑄{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(\delta_{Q}% ,\hat{Q})}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}% {\psi_{Q}}}(Q)( italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT , over^ start_ARG italic_Q end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ) // INT4 ワープごと。 (δK,K^)=ψK(K)subscript𝛿𝐾^𝐾subscript𝜓𝐾𝐾~{}~{}{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(% \delta_{K},\hat{K})}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{% rgb}{0,0,1}{\psi_{K}}}(K)( italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT , over^ start_ARG italic_K end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K )(δV,V^)=ψV(V)subscript𝛿𝑉^𝑉subscript𝜓𝑉𝑉~{}~{}{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}{(% \delta_{V},\hat{V})}}={\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{% rgb}{0,0,1}{\psi_{V}}}(V)( italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT , over^ start_ARG italic_V end_ARG ) = italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V )// FP8 チャンネルごと。
Q^^𝑄\hat{Q}over^ start_ARG italic_Q end_ARGTm=N/bqsubscript𝑇𝑚𝑁subscript𝑏𝑞T_{m}={N}/{b_{q}}italic_T start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT = italic_N / italic_b start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPTブロック{Q^i}subscript^𝑄𝑖\{\hat{Q}_{i}\}{ over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT }に分割する; K^^𝐾\hat{K}over^ start_ARG italic_K end_ARG, V^^𝑉\hat{V}over^ start_ARG italic_V end_ARG, および ΔSΔ𝑆\Delta Sroman_Δ italic_STn=N/bkvsubscript𝑇𝑛𝑁subscript𝑏𝑘𝑣T_{n}={N}/{b_{kv}}italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT = italic_N / italic_b start_POSTSUBSCRIPT italic_k italic_v end_POSTSUBSCRIPTブロック{K^i}subscript^𝐾𝑖\{\hat{K}_{i}\}{ over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT }, {V^i}subscript^𝑉𝑖\{\hat{V}_{i}\}{ over^ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT }, および {ΔSi}Δsubscript𝑆𝑖\{\Delta S_{i}\}{ roman_Δ italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT }に分割する;
for i=1𝑖1i=1italic_i = 1 to Tmsubscript𝑇𝑚T_{m}italic_T start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT do
Q^isubscript^𝑄𝑖\hat{Q}_{i}over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTδQ[icw:(i+1)cw]\delta_{Q}[i*c_{w}:(i+1)*c_{w}]italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_i ∗ italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT : ( italic_i + 1 ) ∗ italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ]をSMにロードする;
for j in [1, Tnsubscript𝑇𝑛T_{n}italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT] do
K^jsubscript^𝐾𝑗\hat{K}_{j}over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT, V^jsubscript^𝑉𝑗\hat{V}_{j}over^ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT, および δK[j]subscript𝛿𝐾delimited-[]𝑗\delta_{K}[j]italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT [ italic_j ]をSMにロードする;
w=range(cw),st=wcwformulae-sequence𝑤rangesubscript𝑐𝑤𝑠𝑡𝑤subscript𝑐𝑤w=\text{range}(c_{w}),st=w*c_{w}italic_w = range ( italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ) , italic_s italic_t = italic_w ∗ italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT
Sij[st:st+cw]=Matmul(Q^i[st:st+cw],K^j)×δQ[st+w]×δK[j]S_{i}^{j}[st:st+c_{w}]=\mathrm{Matmul}(\hat{Q}_{i}[st:st+c_{w}],\hat{K}_{j}^{% \top})\times\delta_{Q}[st+w]\times\delta_{K}[j]italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT [ italic_s italic_t : italic_s italic_t + italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ] = roman_Matmul ( over^ start_ARG italic_Q end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT [ italic_s italic_t : italic_s italic_t + italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPT ] , over^ start_ARG italic_K end_ARG start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) × italic_δ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT [ italic_s italic_t + italic_w ] × italic_δ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT [ italic_j ] + ΔSjΔsubscript𝑆𝑗\Delta S_{j}roman_Δ italic_S start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT; // cwsubscript𝑐𝑤c_{w}italic_c start_POSTSUBSCRIPT italic_w end_POSTSUBSCRIPTワープで並列化。
mij=max(mij1,rowmax(Sij))superscriptsubscript𝑚𝑖𝑗maxsuperscriptsubscript𝑚𝑖𝑗1rowmaxsuperscriptsubscript𝑆𝑖𝑗m_{i}^{j}=\mathrm{max}(m_{i}^{j-1},\mathrm{rowmax}(S_{i}^{j}))italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = roman_max ( italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT , roman_rowmax ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ) ), P~ij=exp(Sijmij)superscriptsubscript~𝑃𝑖𝑗expsuperscriptsubscript𝑆𝑖𝑗superscriptsubscript𝑚𝑖𝑗\widetilde{P}_{i}^{j}=\mathrm{exp}(S_{i}^{j}-m_{i}^{j})over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = roman_exp ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT - italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ), lij=emij1mij+rowsum(P~ij)superscriptsubscript𝑙𝑖𝑗superscript𝑒superscriptsubscript𝑚𝑖𝑗1superscriptsubscript𝑚𝑖𝑗rowsumsuperscriptsubscript~𝑃𝑖𝑗l_{i}^{j}=e^{m_{i}^{j-1}-m_{i}^{j}}+\mathrm{rowsum}(\widetilde{P}_{i}^{j})italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = italic_e start_POSTSUPERSCRIPT italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT - italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT + roman_rowsum ( over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ) ;
Oij=diag(emij1mij)1Oij1+superscriptsubscript𝑂𝑖𝑗limit-fromdiagsuperscriptsuperscript𝑒superscriptsubscript𝑚𝑖𝑗1superscriptsubscript𝑚𝑖𝑗1superscriptsubscript𝑂𝑖𝑗1O_{i}^{j}=\mathrm{diag}(e^{m_{i}^{j-1}-m_{i}^{j}})^{-1}O_{i}^{j-1}+italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT = roman_diag ( italic_e start_POSTSUPERSCRIPT italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT - italic_m start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j - 1 end_POSTSUPERSCRIPT + Matmul((P~ij448).to(FP8.e4m3),Vj)Matmulsuperscriptsubscript~𝑃𝑖𝑗448.to(FP8.e4m3)subscript𝑉𝑗\mathrm{Matmul}((\widetilde{P}_{i}^{j}*448)\text{.to(FP8.e4m3)},V_{j})roman_Matmul ( ( over~ start_ARG italic_P end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_j end_POSTSUPERSCRIPT ∗ 448 ) .to(FP8.e4m3) , italic_V start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) ;
end for
δVsubscript𝛿𝑉\delta_{V}italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPTをSMにロードする;
Oi=diag(liTn)1OiTn/448δVsubscript𝑂𝑖diagsuperscriptsuperscriptsubscript𝑙𝑖subscript𝑇𝑛1superscriptsubscript𝑂𝑖subscript𝑇𝑛448subscript𝛿𝑉O_{i}=\mathrm{diag}(l_{i}^{T_{n}})^{-1}O_{i}^{T_{n}}/448*\delta_{V}italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_diag ( italic_l start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT end_POSTSUPERSCRIPT / 448 ∗ italic_δ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ;
Oisubscript𝑂𝑖O_{i}italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTを書き込む ;
end for
O={Oi}𝑂subscript𝑂𝑖O=\{O_{i}\}italic_O = { italic_O start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT }, O=O+Vm𝑂𝑂subscript𝑉𝑚O=O+\overrightarrow{V}_{m}italic_O = italic_O + over→ start_ARG italic_V end_ARG start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT
return O𝑂Oitalic_O
表4: 平均精度 CogvideoXモデルの全層にわたる(P~,V)~𝑃𝑉(\widetilde{P},V)( over~ start_ARG italic_P end_ARG , italic_V )の異なるデータ型を使用した場合。ここで(Q,K)𝑄𝐾(Q,K)( italic_Q , italic_K )は平滑化されている。
Q,K𝑄𝐾Q,Kitalic_Q , italic_K P~,V~𝑃𝑉\widetilde{P},Vover~ start_ARG italic_P end_ARG , italic_V Cos Sim \uparrow Relative L1 \downarrow RMSE \downarrow
INT4 INT8 77.05% 0.5618 0.5044
E5M2 99.20% 0.0905 0.0903
E4M3 99.44% 0.0683 0.0347
FP16 99.45% 0.0649 0.0335
表5: 最悪精度 CogvideoXモデルの全層にわたる(P~,V)~𝑃𝑉(\widetilde{P},V)( over~ start_ARG italic_P end_ARG , italic_V )の異なるデータ型を使用した場合。ここで(Q,K)𝑄𝐾(Q,K)( italic_Q , italic_K )は平滑化されている。
Q,K𝑄𝐾Q,Kitalic_Q , italic_K P~,V~𝑃𝑉\widetilde{P},Vover~ start_ARG italic_P end_ARG , italic_V Cos Sim \uparrow Relative L1 \downarrow RMSE \downarrow
INT4 INT8 19.52% 0.9579 1.4483
E5M2 94.94% 0.2327 0.2361
E4M3 96.70% 0.1956 0.0779
FP16 96.76% 0.1916 0.0775
表6: CogvideoXモデルの実際のテンソルにおけるV𝑉Vitalic_Vの平滑化の有無による精度の例。
Smooth V Cos Sim \uparrow Relative L1 \downarrow RMSE \downarrow
98.25% 0.1980 0.2387
99.75% 0.0406 0.0773
Refer to caption
図6: FP22データ型で表現されたP~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARGの1行とV𝑉Vitalic_Vの1列の内積精度の例。
表7: mma(f8f8f32)のFP8行列乗算命令の誤差。
Precision of Accumulated Value E8M13 E8M23
Error compared to FP32 0 ||||mma(f32.f16.f16.f32) - mma(f32.f8.f8.f32)||||

3.4 Smooth V

SageAttentionの欠点である、PV𝑃𝑉PVitalic_P italic_Vに対するFP16累算器を持つFP16 MatmulがRTX4090のようなGPUでのみ効果的であることを避けるため、我々はP𝑃Pitalic_PV𝑉Vitalic_VをFP8に量子化してFP8テンソルコアの普遍的な加速を活用するアプローチを採用した。しかし、我々はAda アーキテクチャ上のmma(f32f8f8f32)命令の累算器が実際にはFP22、具体的には1符号ビット、8指数ビット、13仮数ビットであることを発見した。具体的には、mma(f32f8f8f32)命令C=AB+D𝐶𝐴𝐵𝐷C=AB+Ditalic_C = italic_A italic_B + italic_Dにおいて、A,B𝐴𝐵A,Bitalic_A , italic_BはFP8データ型のテンソルで、C,D𝐶𝐷C,Ditalic_C , italic_DはFP32データ型のテンソルである。我々はA,B𝐴𝐵A,Bitalic_A , italic_Bをゼロに初期化し、D𝐷Ditalic_Dを変化させて累算器のデータ型をテストした。表7に示すように、D𝐷Ditalic_Dが1符号ビット、8指数ビット、13仮数ビットで初期化されると、C𝐶Citalic_Cの値はmma(f16f16f32)命令の結果と一致する。しかし、D𝐷Ditalic_Dが13ビット以上の仮数ビットで初期化されると、C𝐶Citalic_Cの誤差はmma(f32f16f16f32)mma(f32f8f8f32)の結果の差に対応する。 したがって、FP8に量子化された行列P~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARGV𝑉Vitalic_Vの行列乗算は、FP32累算器を使用する場合と比較して、ある程度の精度損失を被る。この精度損失を可能な限り軽減するために、我々はV𝑉Vitalic_Vを滑らかにすることを提案する:

γ(V)=Vmean(V),𝛾𝑉𝑉mean𝑉\displaystyle\gamma(V)=V-\mathrm{mean}(V),italic_γ ( italic_V ) = italic_V - roman_mean ( italic_V ) , mean(V)=Vmmean𝑉subscript𝑉𝑚\displaystyle~{}~{}~{}~{}\mathrm{mean}(V)=\overrightarrow{V_{m}}roman_mean ( italic_V ) = over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG
ϕV(V)=subscriptitalic-ϕ𝑉𝑉absent\displaystyle\phi_{V}(V)=italic_ϕ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V ) = ψVγ,Vmsubscript𝜓𝑉𝛾subscript𝑉𝑚\displaystyle\psi_{V}\circ\gamma,~{}\overrightarrow{V_{m}}italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ∘ italic_γ , over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG (7)

6に示すように、この戦略は以下の理由によりP~V~𝑃𝑉\widetilde{P}Vover~ start_ARG italic_P end_ARG italic_Vの値に対するFP22の精度を向上させる:P~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARGの各行は0から1の値範囲にわたり、V𝑉Vitalic_Vの各列は一貫してチャンネル方向の外れ値を持ち、それらは排他的に正または負で、通常8から9の範囲にある。結果として、P~V~𝑃𝑉\widetilde{P}Vover~ start_ARG italic_P end_ARG italic_Vの値はかなり大きくなる可能性がある。しかし、浮動小数点数の表現範囲は一様ではなく、ゼロ付近でより密である。したがって、チャンネル次元に沿って平均をV𝑉Vitalic_Vから減算することで、P~V~𝑃𝑉\widetilde{P}Vover~ start_ARG italic_P end_ARG italic_Vの値はゼロに近くなり、より高い表現精度が得られる。我々は、このような説明と戦略がコミュニティによって報告された多くの問題111https://github.com/triton-lang/triton/issues/4476, https://github.com/triton-lang/triton/issues/5065を解決できると考えている。 さらに、アテンション計算の正確性を維持するためには、Vmsubscript𝑉𝑚\overrightarrow{V_{m}}over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARGO𝑂Oitalic_Oの最終計算に加えるだけで十分である:O=O+Vm𝑂𝑂subscript𝑉𝑚O=O+\overrightarrow{V_{m}}italic_O = italic_O + over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG。これは、P~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARG行列の各行の合計が1に等しいため、P~Vm=Vm~𝑃subscript𝑉𝑚subscript𝑉𝑚\widetilde{P}\overrightarrow{V_{m}}=\overrightarrow{V_{m}}over~ start_ARG italic_P end_ARG over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG = over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARGとなるからである。 言い換えれば、この方法はV𝑉Vitalic_Vを2つの部分に分解する:Vmsubscript𝑉𝑚\overrightarrow{V_{m}}over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARGV𝑉Vitalic_VV𝑉Vitalic_Vについては、各列の値をゼロの周りに中心化し、これにより量子化されたP~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARG行列の行と量子化されたV𝑉Vitalic_V行列の列との内積結果がゼロに近くなる。これによりFP22の表現がより正確になる。一方、Vmsubscript𝑉𝑚\overrightarrow{V_{m}}over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARGはFP16で保持され、最後にO𝑂Oitalic_Oに加算されるため、Vmsubscript𝑉𝑚\overrightarrow{V_{m}}over→ start_ARG italic_V start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG部分の精度損失を引き起こさない。

6は、CogvideoXからサンプリングされた実際のテンソルに対する、V𝑉Vitalic_Vの平滑化の有無によるアテンション精度を示している。これは、Q,K𝑄𝐾Q,Kitalic_Q , italic_KをINT4に、P~,V~𝑃𝑉\widetilde{P},Vover~ start_ARG italic_P end_ARG , italic_VをFP8に量子化する際に、Vを平滑化することでSageAttention2の精度を向上させることができることを示している。

3.5 Quantization for Q, K, P, V

Q,K𝑄𝐾Q,Kitalic_Q , italic_Kの量子化。我々は、ワープ単位でのψQ(Q)subscript𝜓𝑄𝑄\psi_{Q}(Q)italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q )ψK(K)subscript𝜓𝐾𝐾\psi_{K}(K)italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K )を提案する。これは第一に、QK𝑄superscript𝐾topQK^{\top}italic_Q italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTの内部軸のスケール係数を非量子化に使用できないため、チャンネル単位の量子化が実現不可能であるためである(Xiao et al., 2023a)。第二に、表3と表3に示すように、我々はCogvideoXのすべての層にわたる実際のQ,K,V𝑄𝐾𝑉Q,K,Vitalic_Q , italic_K , italic_Vを使用して、トークン単位、ワープ単位、ブロック単位、テンソル単位の粒度でのINT4量子化の平均精度と最悪ケースの精度を比較した。結果は、ワープ単位の量子化の精度がトークン単位に非常に近く、ブロック単位やテンソル単位よりもはるかに優れていることを示している。さらに、セクション3.2で議論したように、ワープ単位の量子化はトークン単位よりも非量子化のオーバーヘッドが少ない。

P~,V~𝑃𝑉\widetilde{P},Vover~ start_ARG italic_P end_ARG , italic_Vの量子化。我々はψP(P~)subscript𝜓𝑃~𝑃\psi_{P}(\widetilde{P})italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG )ψV(V)subscript𝜓𝑉𝑉\psi_{V}(V)italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V )にFP8、具体的にはE4M3データ型を選択した。これには二つの理由がある。第一に、ほとんどのGPUがFP8行列乗算演算をサポートするテンソルコアを持っており、これはFP16を使用する場合の2倍の速度である。第二に、表5と表5は、CogvideoXのすべての層にわたる実際のQ,K,V𝑄𝐾𝑉Q,K,Vitalic_Q , italic_K , italic_Vを使用して、P~,V~𝑃𝑉\widetilde{P},Vover~ start_ARG italic_P end_ARG , italic_Vに使用される異なるデータ型の平均精度と最悪精度を示している。E4M3を使用する精度がFP16を使用する場合に非常に近く、E5M2やINT8よりも優れていることがわかる。 我々は、ψP(P~)subscript𝜓𝑃~𝑃\psi_{P}(\widetilde{P})italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG )をブロック単位で、ψV(V)subscript𝜓𝑉𝑉\psi_{V}(V)italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V )をチャンネル単位で使用することを提案する。これには三つの理由がある。第一に、P~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARGのチャンネル単位量子化とV𝑉Vitalic_Vのトークン単位量子化は、非量子化に外部軸のスケール係数が必要なため実現不可能である。第二に、P~=exp(Sirowmax(Si))~𝑃expsubscript𝑆𝑖rowmaxsubscript𝑆𝑖\widetilde{P}=\mathrm{exp}(S_{i}-\mathrm{rowmax}(S_{i}))over~ start_ARG italic_P end_ARG = roman_exp ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - roman_rowmax ( italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) )であり、ここでSisubscript𝑆𝑖S_{i}italic_S start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPTQ𝑄Qitalic_QKsuperscript𝐾topK^{\top}italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTのブロックの行列乗算結果であり、P~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARGの各行の最大値は1である。したがって、我々はブロックP~~𝑃\widetilde{P}over~ start_ARG italic_P end_ARGに単一の静的スケールs=1448𝑠1448s=\frac{1}{448}italic_s = divide start_ARG 1 end_ARG start_ARG 448 end_ARGを割り当てることができ、その精度はトークン単位の量子化と等しい。第三に、チャンネル単位の量子化はV𝑉Vitalic_Vのチャンネル単位の外れ値に対処できる。

精度指標。我々は、量子化された注意出力Osuperscript𝑂O^{\prime}italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPTを全精度の注意出力O𝑂Oitalic_Oと比較するために3つの指標を使用する:まず、Osuperscript𝑂O^{\prime}italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPTO𝑂Oitalic_O1×n1𝑛1\times n1 × italic_nの形状のベクトルに平坦化する。次に、コサイン類似度:CosSim=OO/O2O2𝐶𝑜𝑠𝑆𝑖𝑚𝑂superscript𝑂superscript𝑂2superscript𝑂2CosSim=\sum OO^{\prime}/\smash{\sqrt{\sum O^{2}}}\smash{\sqrt{\sum O^{\prime 2% }}}italic_C italic_o italic_s italic_S italic_i italic_m = ∑ italic_O italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT / square-root start_ARG ∑ italic_O start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG square-root start_ARG ∑ italic_O start_POSTSUPERSCRIPT ′ 2 end_POSTSUPERSCRIPT end_ARG、相対L1距離:L1=|OO|/|O|𝐿1𝑂superscript𝑂𝑂L1=\sum|O-O^{\prime}|/\sum|O|italic_L 1 = ∑ | italic_O - italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT | / ∑ | italic_O |、二乗平均平方根誤差:RMSE=(1/n)(OO)2𝑅𝑀𝑆𝐸1𝑛superscript𝑂superscript𝑂2RMSE=\sqrt{(1/n)\sum(O-O^{\prime})^{2}}italic_R italic_M italic_S italic_E = square-root start_ARG ( 1 / italic_n ) ∑ ( italic_O - italic_O start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG

表8: SageAttention2の2つのカーネル実装。
     Kernel      ψQ(Q)subscript𝜓𝑄𝑄\psi_{Q}(Q)italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q ), ψK(K)subscript𝜓𝐾𝐾\psi_{K}(K)italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K )      ψP(P~)subscript𝜓𝑃~𝑃\psi_{P}(\widetilde{P})italic_ψ start_POSTSUBSCRIPT italic_P end_POSTSUBSCRIPT ( over~ start_ARG italic_P end_ARG ), ψV(V)subscript𝜓𝑉𝑉\psi_{V}(V)italic_ψ start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_V )
     SageAttn2-4b      INT4, per-warp      FP8, per-block;  FP8, per-channel
     SageAttn2-8b      INT8, per-warp      FP8, per-block;  FP8, per-channel
Refer to caption
図7: Llama3.1CogvideoXの異なる入力に対する、異なる層とタイムステップにおけるSageAttn-4bcossim(1L1)𝑐𝑜𝑠𝑠𝑖𝑚1𝐿1cossim*(1-L1)italic_c italic_o italic_s italic_s italic_i italic_m ∗ ( 1 - italic_L 1 )の平均と標準偏差。

3.6 Adaptive Quantization over Layer and Timestep

セクション3.5の議論に基づき、我々はψQ(Q)subscript𝜓𝑄𝑄\psi_{Q}(Q)italic_ψ start_POSTSUBSCRIPT italic_Q end_POSTSUBSCRIPT ( italic_Q )ψK(K)subscript𝜓𝐾𝐾\psi_{K}(K)italic_ψ start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( italic_K )にINT4またはINT8量子化を使用する選択に基づいて2つのアテンションカーネル(表8参照)を実装する。 これらのカーネルの速度は(SageAttn2-4b>>>SageAttn2-8b)の順であるが、精度の順序は逆である。 SageAttn2-4bは多くのタスクで十分に機能する可能性があり、例えばLlama3.1のLambdaベンチマークでのエンドツーエンドの指標損失はごくわずかであるが、SageAttn2-4bは一部の困難なシナリオでは不十分である。したがって、我々はまずSageAttn2-4bSageAttn2-8bを使用して様々なモデルの精度性能の詳細な分析を行う。その後、効果的で使いやすい適応的手法を提案する。

レイヤーとタイムステップにわたる精度。7は、Llama3.1CogvideoXの異なる入力の異なるレイヤーとタイムステップにわたるSageAttn-4bの精度の平均と標準偏差を示している。これはCosSim(1L1)𝐶𝑜𝑠𝑆𝑖𝑚1𝐿1CosSim*(1-L1)italic_C italic_o italic_s italic_S italic_i italic_m ∗ ( 1 - italic_L 1 )を用いて計算されている。特定のレイヤーとタイムステップが特定の範囲の誤差を示すことが観察できる。

我々の手法。 我々の適応的戦略は以下の通りである:まず、複数のランダムな入力を使用して、モデル内の各(レイヤー、タイムステップ)の組み合わせに対して、SageAttn2-4bの平均精度をフル精度のアテンションと比較して評価する。タイムステップの次元がない言語モデルなどのモデルの場合、評価はレイヤーの次元に限定される。次に、(レイヤー、タイムステップ)の組み合わせをCosSim(1L1)𝐶𝑜𝑠𝑆𝑖𝑚1𝐿1CosSim*(1-L1)italic_C italic_o italic_s italic_S italic_i italic_m ∗ ( 1 - italic_L 1 )値に基づいて降順にソートする。そして、これらの組み合わせの最小ξ𝜉\xiitalic_ξ%を特定する。最後に、特定された(レイヤー、タイムステップ)の組み合わせに対して、他のすべてのプロンプトにわたって一貫してSageAttn2-8bを適用する。我々はこのような戦略をSageAttn-mixと呼び、各モデルに使用したξ𝜉\xiitalic_ξ%の値をセクション4.1で報告する。

Refer to caption
図8: SageAttention2とベースラインの速度比較(RTX4090、headdim=64)。
Refer to caption
図9: SageAttention2とベースラインの速度比較(RTX4090、headdim=128)。
Refer to caption
図10: SageAttention2とベースラインの速度比較(RTX4090、headdim=256)。
Refer to caption
図11: SageAttention2とベースラインの速度比較(L20、headdim=64)。
Refer to caption
図12: SageAttention2とベースラインの速度比較(L20、headdim=128)。
Refer to caption
図13: SageAttention2とベースラインの速度比較(L20、headdim=256)。

4 Experiment

4.1 Setup

モデル。 我々は、言語、画像、動画生成の代表的なモデルの多様なセットにわたってSageAttention2の有効性を検証する。 具体的には、5つのモデルで実験を行う:テキスト生成にはLlama2 (7B) (Touvron et al., 2023)Llama3.1 (8B) (Dubey et al., 2024)GLM4 (9B) (GLM et al., 2024)を、テキストから動画生成にはCogvideoX (2B) (Yang et al., 2024)Open-Sora (Zheng et al., 2024)を、テキストから画像生成にはFlux (schnell) (Black Forest Labs, 2023)を、画像分類にはTIMM (Wightman, 2019)を使用する。

データセット。 テキスト生成モデルは4つのゼロショットタスクで評価される:WikiText (Merity et al., 2022)でモデルの予測信頼度を評価し、LAMBADA (Paperno et al., 2016)で文脈理解を評価し、MMLU (Hendrycks et al., 2020)で様々な分野の知識を測定し、Longbench (Bai et al., 2023)で二言語、マルチタスク、長文脈理解能力の包括的評価を行う。 テキストから動画生成モデルはopen-sora (Zheng et al., 2024)のプロンプトセットを用いて評価される。 FluxはMJHQ-30K (Li et al., 2024)で評価される。 TIMMは3つの画像データセットで評価される:ImageNet (Deng et al., 2009)、ImageNet-Sketch (Sketch) (Wang et al., 2019)、ImageNet-Rendition (ImageNet-r) (Hendrycks et al., 2021)である。

評価指標。 テキスト生成モデルについては、WikiTextにはパープレキシティ (ppl.) (Jelinek et al., 1977)を、LAMBADAとMMLUには正確性 (Acc.)を、Longbenchにはスコア (Bai et al., 2023)を使用する。 テキストから動画生成モデルについては、Zhao et al. (2024)に従い、生成された動画の品質を5つの指標で評価する:テキストと動画の整合性を測るCLIPSIMとCLIP-Temp (CLIP-T) (Liu et al., 2024)、動画の美的品質と技術的品質をそれぞれ評価する(VQA-a)と(VQA-t)、時間的一貫性を評価するFlow-score (FScore) (Wu et al., 2023)である。 テキストから画像生成モデルについては、生成された画像をMJHQ-30Kデータセットの画像と3つの側面で比較する:忠実度評価にはFID (Heusel et al., 2017)とsFID (Salimans et al., 2016)を、テキストと画像の整合性にはClipscore (CLIP) (Hessel et al., 2021)を、人間の選好にはImageReward (IR) (Xu et al., 2024)を使用する。 TIMMについては、正確性を使用する。

実装の詳細。 我々はCUDAを用いてアテンションカーネルを実装し、RTX4090とL20 GPUを搭載したUbuntu 22.04サーバーで実験を行う。

ベースライン。 (1) SmoothAttn。Qserve (Lin et al., 2024)に従い、行列Q,K𝑄𝐾Q,Kitalic_Q , italic_Kに対して平滑化因子α=0.5𝛼0.5\alpha=0.5italic_α = 0.5でスムースクォンタイゼーションを適用する。 (2) HadmdAttn。Quarot (Ashkboos et al., 2024)に従い、INT4量子化の前に行列Q,K𝑄𝐾Q,Kitalic_Q , italic_Kにランダムアダマール変換を適用する。

SageAttn-mixのハイパーパラメータ。 我々はξ𝜉\xiitalic_ξとして、Llama2に50%、CogvideoXに60%、Open-Soraに25%、Llama3.1GLM4に30%、FluxTIMMに0%を使用する。

Refer to caption
図14: CogvideoXからの比較例、プロンプトはopen-soraプロンプトセットからサンプリングされている。
表9: 異なる平滑化手法を用いた全層にわたる平均精度
Method CosSim \uparrow Relative L1 \downarrow RMSE \downarrow
None 80.04% 0.3906 0.2223
HadmdAttn 79.77% 0.3782 0.2180
SmoothAttn 90.21% 0.3383 0.1952
Smooth K (Ours) 98.07% 0.1493 0.0743
Smooth Q (Ours) 98.30% 0.1250 0.0712
SageAttn2-4b 99.46% 0.0648 0.0334
表10: 異なる平滑化手法を用いた全層にわたる最悪精度
Method CosSim \uparrow Relative L1 \downarrow RMSE \downarrow
None 4.83% 0.9979 0.7784
HadmdAttn 4.85% 0.9978 0.7785
SmoothAttn 64.49% 0.9262 0.7204
Smooth K (Ours) 90.86% 0.3565 0.1464
Smooth Q (Ours) 93.10% 0.2989 0.2195
SageAttn2-4b 96.71% 0.1956 0.0779
表11: テキスト、画像、動画生成モデルにおけるエンドツーエンドの指標損失。
Model Attention WikiText (Ppl.) \downarrow Lambda (Acc.) \uparrow MMLU (Acc.) \uparrow Longbench \uparrow
Llama2 Full-Precision 5.823 0.886 0.439 -
HadmdAttn 6.706 0.865 0.355 -
SmoothAttn 6.690 0.871 0.395 -
SageAttn2-4b 6.018 0.886 0.436 -
SageAttn2-mix 5.883 0.883 0.431 -
Llama3.1 Full-Precision 6.013 0.815 0.635 49.40
HadmdAttn 7.661 0.756 0.502 44.62
SmoothAttn 7.087 0.788 0.551 43.76
SageAttn2-4b 6.219 0.808 0.617 48.61
SageAttn2-mix 6.131 0.816 0.629 49.01
GLM4 Full-Precision 7.241 0.432 0.743 49.78
HadmdAttn 7.932 0.435 0.676 46.27
SmoothAttn 8.841 0.442 0.599 43.10
SageAttn2-4b 7.341 0.435 0.732 49.06
SageAttn2-mix 7.303 0.434 0.737 49.77
Model Attention CLIPSIM \uparrow CLIP-T \uparrow VQA-a \uparrow VQA-t \uparrow FScore \uparrow
CogvideoX Full-Precision 0.1836 0.9975 77.605 75.360 3.006
HadmdAttn 0.1742 0.9877 36.028 23.786 0.550
SmoothAttn 0.1763 0.9870 37.444 42.184 0.792
SageAttn2-4b 0.1813 0.9969 77.276 75.147 2.070
SageAttn2-mix 0.1816 0.9976 75.686 78.600 2.884
Open-Sora Full-Precision 0.1831 0.9996 46.713 59.553 0.368
SageAttn2-4b 0.1821 0.9994 42.270 55.965 0.364
SageAttn2-mix 0.1814 0.9994 44.509 59.097 0.383
        Model         Attention        FID \downarrow        sFID \downarrow        CLIP \uparrow        IR \uparrow
        Flux         Full-Precision        11.303        17.603        32.603        0.9745
        HadmdAttn        11.163        17.693        32.592        0.9638
        SmoothAttn        10.941        18.098        32.582        0.9613
        SageAttn2-4b        10.563        17.052        32.631        0.9747
Model Attention ImageNet (Acc.) \uparrow Sketch (Acc.) \uparrow ImageNet-r (Acc.) \uparrow
TIMM Full-Precision 84.79% 45.32% 59.55%
SageAttn2-4b 84.67% 45.07% 59.11%

4.2 Speed and Accuracy of Kernels

速度。我々は、headdim=64、headdim=128、およびheaddim=256の構成を用いて、因果マスクの有無両方の場合で、SageAttention2の速度をベースラインと比較する実験を行った(Vaswani, 2017)。具体的には、図8、図9、および図10は、RTX4090上での様々なシーケンス長におけるSageAttention2とベースラインの速度を示している。これらの結果は、SageAttention2が最大485 TOPSを達成し、FlashAttention2より3.1倍速く、xformersより5.4倍速いことを示している。 図11、図12、および図13は、L20 GPU上での結果を示しており、SageAttention2が最大288 TOPSを達成し、FlashAttention2より2.7倍速く、xformersより4.6倍速いことを示している。

精度。10および表10は、CogvideoXのすべての層にわたるINT4 Q,K𝑄𝐾Q,Kitalic_Q , italic_KおよびFP8 P,V𝑃𝑉P,Vitalic_P , italic_Vを用いた異なる手法の平均精度と最悪精度を示している。結果は、SageAttn-4bの精度が他のベースラインよりも優れていることを示している。

Refer to caption
図15: 異なる割合のSageAttn-8bを使用したエンドツーエンドのパフォーマンス。

4.3 End-to-end Performance

メトリクスの損失。 我々は、フル精度のアテンションと比較して、SageAttention2を使用した様々なモデルのエンドツーエンドメトリクスを評価した。 詳細な評価結果は、Llama2Llama3.1GLM4CogvideoXOpen-SoraFlux、およびTIMMについて、表11に示されている。結果は、SageAttn-4bがすべてのベースラインを上回り、すべてのモデルにおいてエンドツーエンドの精度のほとんどを維持していることを示している。さらに、適応的量子化技術を用いることで、SageAttn-mixはフル精度のアテンションに匹敵する性能を達成している。

4.4 Ablation Study

適応的量子化。 適応的量子化におけるSageAttn-8bの異なる比率の使用の影響を分析するために、表15Llama3.1のパープレキシティの変化を様々なSageAttn-8bの比率で示している。SageAttn-4bのみを使用した場合でも、全体的なエンドツーエンドの表現は十分に良好であり、使用するSageAttn-8bの比率が高いほど、精度はフル精度のアテンションを使用した場合に近づくことが観察できる。

Q, Kの平滑化のオーバーヘッド。 Q,K𝑄𝐾Q,Kitalic_Q , italic_Kの平滑化のオーバーヘッドには、QQm𝑄subscript𝑄𝑚Q-\overrightarrow{Q_{m}}italic_Q - over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARGKmean(K)𝐾mean𝐾K-\mathrm{mean}(K)italic_K - roman_mean ( italic_K )、およびQmKsubscript𝑄𝑚superscript𝐾top\overrightarrow{Q_{m}}K^{\top}over→ start_ARG italic_Q start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT end_ARG italic_K start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPTのみが含まれることに注意されたい。これらの速度オーバーヘッドは、アテンションカーネルの約3.5%を占める。

5 Conclusion and Future Work

我々は、注意機構の効率的かつ正確な4ビット量子化手法であるSageAttention2を提案する。まず、(Q,K)𝑄𝐾(Q,K)( italic_Q , italic_K )をワープレベルの粒度でINT4に量子化し、(P~,V)~𝑃𝑉(\widetilde{P},V)( over~ start_ARG italic_P end_ARG , italic_V )をFP8に量子化することを提案する。次に、行列Q𝑄Qitalic_QV𝑉Vitalic_Vを平滑化する手法を提案し、INT4のQK𝑄𝐾QKitalic_Q italic_KとFP8のPV𝑃𝑉PVitalic_P italic_Vを用いた注意機構の精度を向上させる。さらに、タイムステップと層にわたる量子化精度を分析し、様々なモデルにおけるエンドツーエンドの指標を保証する適応的量子化手法を提案する。我々の実装は、RTX4090上でFlashAttention2とxformersと比較して、それぞれ約3.1倍および5.4倍高速である。広範なテストにより、本稿のアプローチが言語、画像、動画生成を含む様々なモデルにおいてエンドツーエンドの指標を維持することが確認された。

今後の課題。Hopper アーキテクチャ上でのP~V~𝑃𝑉\widetilde{P}Vover~ start_ARG italic_P end_ARG italic_VのFP16アキュムレータを用いたFP8 MatMulの実装とSageAttention2の実装は今後の課題として残されている。

References

  • Ashkboos et al. (2024) Ashkboos, S., Mohtashami, A., Croci, M. L., Li, B., Cameron, P., Jaggi, M., Alistarh, D., Hoefler, T., and Hensman, J. Quarot: Outlier-free 4-bit inference in rotated llms, 2024. URL https://arxiv.org/abs/2404.00456.
  • Bai et al. (2023) Bai, Y., Lv, X., Zhang, J., Lyu, H., Tang, J., Huang, Z., Du, Z., Liu, X., Zeng, A., Hou, L., et al. Longbench: A bilingual, multitask benchmark for long context understanding. arXiv preprint arXiv:2308.14508, 2023.
  • Black Forest Labs (2023) Black Forest Labs. Flux. https://github.com/black-forest-labs/flux, 2023.
  • Chen et al. (2023) Chen, Y., Qian, S., Tang, H., Lai, X., Liu, Z., Han, S., and Jia, J. Longlora: Efficient fine-tuning of long-context large language models. arXiv preprint arXiv:2309.12307, 2023.
  • Choromanski et al. (2020) Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L., et al. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
  • Chu et al. (2021) Chu, X., Tian, Z., Wang, Y., Zhang, B., Ren, H., Wei, X., Xia, H., and Shen, C. Twins: Revisiting the design of spatial attention in vision transformers. Advances in neural information processing systems, 34:9355–9366, 2021.
  • Dao (2023) Dao, T. Flashattention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
  • Dao et al. (2022) Dao, T., Fu, D., Ermon, S., Rudra, A., and Ré, C. Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
  • Deng et al. (2009) Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp.  248–255. Ieee, 2009.
  • Dubey et al. (2024) Dubey, A., Jauhri, A., Pandey, A., Kadian, A., Al-Dahle, A., Letman, A., Mathur, A., Schelten, A., Yang, A., Fan, A., et al. The llama 3 herd of models. arXiv preprint arXiv:2407.21783, 2024.
  • Fu et al. (2024) Fu, T., Huang, H., Ning, X., Zhang, G., Chen, B., Wu, T., Wang, H., Huang, Z., Li, S., Yan, S., Dai, G., Yang, H., and Wang, Y. Moa: Mixture of sparse attention for automatic large language model compression, 2024. URL https://arxiv.org/abs/2406.14909.
  • Gao et al. (2024) Gao, Y., Zeng, Z., Du, D., Cao, S., So, H. K.-H., Cao, T., Yang, F., and Yang, M. Seerattention: Learning intrinsic sparse attention in your llms, 2024. URL https://arxiv.org/abs/2410.13276.
  • gkamradt (2023) gkamradt. Llmtest needle in a haystack - pressure testing llms. https://github.com/gkamradt/LLMTest_NeedleInAHaystack, 2023.
  • GLM et al. (2024) GLM, T., Zeng, A., Xu, B., Wang, B., Zhang, C., Yin, D., Rojas, D., Feng, G., Zhao, H., Lai, H., Yu, H., Wang, H., Sun, J., Zhang, J., Cheng, J., Gui, J., Tang, J., Zhang, J., Li, J., Zhao, L., Wu, L., Zhong, L., Liu, M., Huang, M., Zhang, P., Zheng, Q., Lu, R., Duan, S., Zhang, S., Cao, S., Yang, S., Tam, W. L., Zhao, W., Liu, X., Xia, X., Zhang, X., Gu, X., Lv, X., Liu, X., Liu, X., Yang, X., Song, X., Zhang, X., An, Y., Xu, Y., Niu, Y., Yang, Y., Li, Y., Bai, Y., Dong, Y., Qi, Z., Wang, Z., Yang, Z., Du, Z., Hou, Z., and Wang, Z. Chatglm: A family of large language models from glm-130b to glm-4 all tools, 2024.
  • Hendrycks et al. (2020) Hendrycks, D., Burns, C., Basart, S., Zou, A., Mazeika, M., Song, D., and Steinhardt, J. Measuring massive multitask language understanding. 2020.
  • Hendrycks et al. (2021) Hendrycks, D., Basart, S., Mu, N., Kadavath, S., Wang, F., Dorundo, E., Desai, R., Zhu, T., Parajuli, S., Guo, M., Song, D., Steinhardt, J., and Gilmer, J. The many faces of robustness: A critical analysis of out-of-distribution generalization. ICCV, 2021.
  • Hessel et al. (2021) Hessel, J., Holtzman, A., Forbes, M., Le Bras, R., and Choi, Y. Clipscore: A reference-free evaluation metric for image captioning. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pp.  7514–7528, 2021.
  • Heusel et al. (2017) Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in neural information processing systems, 30, 2017.
  • Jelinek et al. (1977) Jelinek, F., Mercer, R. L., Bahl, L. R., and Baker, J. K. Perplexity—a measure of the difficulty of speech recognition tasks. The Journal of the Acoustical Society of America, 62(S1):S63–S63, 1977.
  • Jiang et al. (2024) Jiang, H., Li, Y., Zhang, C., Wu, Q., Luo, X., Ahn, S., Han, Z., Abdi, A. H., Li, D., Lin, C.-Y., et al. Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention. arXiv preprint arXiv:2407.02490, 2024.
  • Katharopoulos et al. (2020) Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, pp.  5156–5165. PMLR, 2020.
  • Lefaudeux et al. (2022) Lefaudeux, B., Massa, F., Liskovich, D., Xiong, W., Caggiano, V., Naren, S., Xu, M., Hu, J., Tintore, M., Zhang, S., Labatut, P., Haziza, D., Wehrstedt, L., Reizenstein, J., and Sizov, G. xformers: A modular and hackable transformer modelling library. https://github.com/facebookresearch/xformers, 2022.
  • Li et al. (2024) Li, D., Kamko, A., Akhgari, E., Sabet, A., Xu, L., and Doshi, S. Playground v2.5: Three insights towards enhancing aesthetic quality in text-to-image generation, 2024.
  • (24) Li, K., Wang, Y., Peng, G., Song, G., Liu, Y., Li, H., and Qiao, Y. Uniformer: Unified transformer for efficient spatial-temporal representation learning. In International Conference on Learning Representations.
  • Lin et al. (2024) Lin, Y., Tang, H., Yang, S., Zhang, Z., Xiao, G., Gan, C., and Han, S. Qserve: W4a8kv4 quantization and system co-design for efficient llm serving, 2024. URL https://arxiv.org/abs/2405.04532.
  • Liu et al. (2024) Liu, Y., Cun, X., Liu, X., Wang, X., Zhang, Y., Chen, H., Liu, Y., Zeng, T., Chan, R., and Shan, Y. Evalcrafter: Benchmarking and evaluating large video generation models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  22139–22149, 2024.
  • Liu et al. (2021) Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., and Guo, B. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF international conference on computer vision, pp.  10012–10022, 2021.
  • Merity et al. (2022) Merity, S., Xiong, C., Bradbury, J., and Socher, R. Pointer sentinel mixture models. In International Conference on Learning Representations, 2022.
  • Milakov & Gimelshein (2018) Milakov, M. and Gimelshein, N. Online normalizer calculation for softmax. arXiv preprint arXiv:1805.02867, 2018.
  • Paperno et al. (2016) Paperno, D., Kruszewski, G., Lazaridou, A., Pham, N.-Q., Bernardi, R., Pezzelle, S., Baroni, M., Boleda, G., and Fernández, R. The lambada dataset: Word prediction requiring a broad discourse context. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp.  1525–1534, 2016.
  • Salimans et al. (2016) Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved techniques for training gans. Advances in neural information processing systems, 29, 2016.
  • Shah et al. (2024) Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., and Dao, T. Flashattention-3: Fast and accurate attention with asynchrony and low-precision. arXiv preprint arXiv:2407.08608, 2024.
  • Touvron et al. (2023) Touvron, H., Martin, L., Stone, K., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S., et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • Vaswani (2017) Vaswani, A. Attention is all you need. Advances in Neural Information Processing Systems, 2017.
  • Venkataramanan et al. (2023) Venkataramanan, S., Ghodrati, A., Asano, Y. M., Porikli, F., and Habibian, A. Skip-attention: Improving vision transformers by paying less attention. arXiv preprint arXiv:2301.02240, 2023.
  • Wang et al. (2019) Wang, H., Ge, S., Lipton, Z., and Xing, E. P. Learning robust global representations by penalizing local predictive power. Advances in Neural Information Processing Systems, 32, 2019.
  • Wang et al. (2020) Wang, S., Li, B. Z., Khabsa, M., Fang, H., and Ma, H. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
  • Wightman (2019) Wightman, R. Pytorch image models. https://github.com/rwightman/pytorch-image-models, 2019.
  • Wu et al. (2023) Wu, H., Zhang, E., Liao, L., Chen, C., Hou, J., Wang, A., Sun, W., Yan, Q., and Lin, W. Exploring video quality assessment on user generated contents from aesthetic and technical perspectives. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp.  20144–20154, 2023.
  • Xiao et al. (2024) Xiao, C., Zhang, P., Han, X., Xiao, G., Lin, Y., Zhang, Z., Liu, Z., and Sun, M. Infllm: Training-free long-context extrapolation for llms with an efficient context memory. In First Workshop on Long-Context Foundation Models@ ICML 2024, 2024.
  • Xiao et al. (2023a) Xiao, G., Lin, J., Seznec, M., Wu, H., Demouth, J., and Han, S. Smoothquant: Accurate and efficient post-training quantization for large language models. In International Conference on Machine Learning, pp.  38087–38099. PMLR, 2023a.
  • Xiao et al. (2023b) Xiao, G., Tian, Y., Chen, B., Han, S., and Lewis, M. Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453, 2023b.
  • Xu et al. (2024) Xu, J., Liu, X., Wu, Y., Tong, Y., Li, Q., Ding, M., Tang, J., and Dong, Y. Imagereward: Learning and evaluating human preferences for text-to-image generation. Advances in Neural Information Processing Systems, 36, 2024.
  • Yang et al. (2024) Yang, Z., Teng, J., Zheng, W., Ding, M., Huang, S., Xu, J., Yang, Y., Hong, W., Zhang, X., Feng, G., et al. Cogvideox: Text-to-video diffusion models with an expert transformer. arXiv preprint arXiv:2408.06072, 2024.
  • Yu et al. (2022) Yu, W., Luo, M., Zhou, P., Si, C., Zhou, Y., Wang, X., Feng, J., and Yan, S. Metaformer is actually what you need for vision. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp.  10819–10829, 2022.
  • Zhang et al. (2024) Zhang, J., wei, J., Zhang, P., Zhu, J., and Chen, J. Sageattention: Accurate 8-bit attention for plug-and-play inference acceleration, 2024. URL https://arxiv.org/abs/2410.02367.
  • Zhao et al. (2024) Zhao, T., Fang, T., Liu, E., Rui, W., Soedarmadji, W., Li, S., Lin, Z., Dai, G., Yan, S., Yang, H., Ning, X., and Wang, Y. Vidit-q: Efficient and accurate quantization of diffusion transformers for image and video generation, 2024.
  • Zheng et al. (2024) Zheng, Z., Peng, X., Yang, T., Shen, C., Li, S., Liu, H., Zhou, Y., Li, T., and You, Y. Open-sora: Democratizing efficient video production for all, March 2024. URL https://github.com/hpcaitech/Open-Sora.