首頁 > AI資訊 > 最新資訊 > PyTorch官方認可!斯坦福博士新作:長上下文LLM推理速度提8倍

PyTorch官方認可!斯坦福博士新作:長上下文LLM推理速度提8倍

新火種    2023-11-01

這兩天,FlashAttention團隊推出了新作:

一種給Transformer架構大模型推理加速的新方法,最高可提速8倍

該方法尤其造福于長上下文LLM,在64k長度的CodeLlama-34B上通過了驗證。

甚至得到了PyTorch官方認可

如果你之前有所關注,就會記得用FlashAttention給大模型加速效果真的很驚艷。

不過它僅限于訓練階段。

因此,這一新成果一出,就有網友表示:

等推理加速等了好久,終于來了。

據介紹,這個新方法也是在FlashAttention的基礎之上衍生而出,主要思想也不復雜:

用并行操作盡快加載Key和Value緩存,然后分別重新縮放再合并結果,最終獲得推理速度上的大幅提升。

提速8倍的長上下文推理方法來了

該方法被命名為Flash-Decoding

背景與動機

根據作者介紹:

LLM的推理(即“解碼”)過程是迭代的,即一次生成一個token,組成一個完整句子需要n個token以及n次前向傳遞。

不過,由于我們可以緩存之前計算出來的token,所以單個生成步驟并不總是依賴于上下文長度。

但有一個操作例外:注意力 (attention),它不能隨著上下文長度靈活擴展。

鑒于長上下文已成趨勢,比如目前最大的開源LLM已達100k(CodeLlama),我們不得不注意到attention在大模型推理過程中浪費了太多時間,時間就是金錢。

更別提attention在batch size上進行擴展時,即使模型上下文相對較短,它也可能成為性能瓶頸(因為模型要讀取的內存量與batch size成比例,而它僅取決于模型其余部分的大小)。

怎么破解?

不可復用的FlashAttention優化

模型在推理也就是解碼過程中,為了計算softmax(queries @keys.transpose)@values這兩個值,生成的每個新token都需要關注先前的所有token。

團隊先前的工作FlashAttention,已經在訓練階段對此操作進行了優化。

當時,FlashAttention解決的主要瓶頸是讀寫中間結果的內存帶寬(例如,Q @ K^T)。

然而,在推理階段,我們要面對的瓶頸變了,導致FlashAttention所做的優化并不能直接拿過來應用。

具體而言:

在階段階段,FlashAttention在batch size和查詢長度維度上進行并行化。

在推理階段,查詢長度通常為1,這意味著如果batch size小于GPU上的流式多處理器數量(例如,A100為108),該操作將僅使用GPU的一小部分。

這對于長上下文情況尤甚,因為長上下文需要較小的batch size才能適應GPU內存。

所以,結果就是,當batch size為1時,FlashAttention將只占用不足1%的GPU,非常不劃算。

當然,你可能會說,不用FlashAttention也行,用矩陣乘法原語來完注意力操作。

不過,作者指出,這種情況又會完全占用GPU,并啟動非常多的寫入和讀取中間結果的內核,也不是最佳辦法。

Flash-Decoding誕生

最終,基于以上考量,作者在FlashAttention的基礎上,添加了一個新的并行化緯度:key和value序列長度

這個方法(即Flash-Decoding)結合上述兩種方法的優點:

與FlashAttention一樣,它在全局內存中存儲的額外數據非常少,但只要上下文長度足夠大,即使batch size很小,它也可以充分利用GPU。

詳細來看,Flash-Decoding一共分為三個步驟

1、先將key和value值分成更小的塊。

2、用FlashAttention并行計算每塊分割的查詢注意力。并為每行和每塊分割寫入一個額外標量:注意力值的log-sum-exp。

3、最后,通過減少所有分割來計算實際輸出,使用log-sum-exp來scale每塊分割的貢獻。

作者指出,由于attention/softmax可以迭代計算,以上所有操作均可行。

并且在Flash-Decoding中,ttention/softmax既可以在分割塊內,也可以跨分割塊來執行最終的縮減,只不過后者可縮減的步驟很少。

而在實際操作中,步驟1不涉及任何GPU操作,因為key和value塊是完整的張量視圖。然后由2個獨立的內核分別執行步驟2和3。

最高提速8倍

驗證環節,作者在CodeLLaMa-34b(架構與Llama 2相同)上對其解碼吞吐量進行了基準測試。

具體以tok/s為單位,測量了512到64k序列長度下的解碼速度(上限為從內存中讀取整個模型以及KV緩存所需的時間),并和多種計算注意力的方法進行對比,包括:

Pytorch,使用純PyTorch原語運行注意力

FlashAttention v2

FasterTransformer:使用FasterTransformer注意力內核

最終,Flash-Decoding最高可將長序列解碼速度提升8倍,并比其他方法具 有更好的擴展性(受長度影響較小)

此外,作者還在A100上對各種序列長度和batch size的縮放多頭注意力進行了微基準測試。

結果顯示,當序列長度擴展到64k時,Flash-Decoding實現了幾乎恒定的運行時間

如何使用?

以下是Flash-Decoding的獲取途徑,戳文末官方博客即可找到地址:

FlashAttention包,2.2版本及以上

xFormers包,0.0.22版本及以上

調度程序將根據問題的大小自動使用Flash-Decoding或 FlashAttention方法。

團隊介紹

目前Flash-Decoding還沒出論文,但作者團隊已透露,這次不再是Tri Dao“單打獨斗”,不過一作仍然是他

Tri Dao今年博士畢業于斯坦福,7月份加盟大模型創業公司Together AI擔任首席科學家。

明年9月將上任普林斯頓大學助理教授,他是FlashAttention v1和v2的主要作者。

剩下三位作者分別是:

Daniel Haziza,Facebook AI Research研究工程師,主要負責xformers(用于訓練加速的開源框架);

Francisco Massa,同Facebook AI Research研究工程師, 主要從事PyTorch相關工作;

Grigory Sizov,Meta機器學習工程師,主要工作是優化GPU上的LLM推理和其他AI工作負載,為PyTorch生態做出過貢獻。

相關推薦
免責聲明
本文所包含的觀點僅代表作者個人看法,不代表新火種的觀點。在新火種上獲取的所有信息均不應被視為投資建議。新火種對本文可能提及或鏈接的任何項目不表示認可。 交易和投資涉及高風險,讀者在采取與本文內容相關的任何行動之前,請務必進行充分的盡職調查。最終的決策應該基于您自己的獨立判斷。新火種不對因依賴本文觀點而產生的任何金錢損失負任何責任。

熱門文章