微調大模型,AMDMI300X就夠了!跟著這篇博客微調Llama3.1405B,效果媲美H100
隨著 AI 模型的參數量越來越大,對算力的需求也水漲船高。比如最近,Llama-3.1 登上了最強開源大模型的寶座,但超大杯 405B 版本的內存就高達 900 多 GB,這對算力構成了更加苛刻的挑戰。如何降低算力的使用成本和使用門檻,已經成為許多公司尋求突破的關鍵。Felafax 就是其中的一家創業公司,致力于簡化 AI 訓練集群的搭建流程。
Nikhil Sonti 和 Nikhin Sonti 創立了 Felafax,他們的口號是在構建開源 AI 平臺,為下一代 AI 硬件服務,將機器學習的訓練成本降低 30%。與英偉達相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性價比,按每美元計算,其性能表現更為出色。最近,Felafax 的聯合創始人 Nikhil Sonti 發布了一篇博客,詳細分享了如何通過 8 張 AMD MI300X GPU 和 JAX 微調 LLaMA 3.1 405B 模型的方法,所有代碼現已開源。
與英偉達 H100 的比較,來源:TensorWave訓練 LLaMA 405B:性能與可擴展性使用 JAX,可以成功地在 AMD GPU 上訓練 LLaMA 405B 模型。我們使用 LoRA 微調,將所有模型權重和 LoRA 參數都設為 bfloat16,LoRA rank 設為 8,LoRA alpha 設為 16:模型大?。篖LaMA 模型的權重占用了約 800GB 的顯存。LoRA 權重 + 優化器狀態:大約占用了 400GB 的顯存。顯存總使用量:占總顯存的 77%,約 1200GB。限制:由于 405B 模型的規模過大,batch 大小和序列長度的空間有限,使用的 batch size 為 16,序列長度為 64。JIT 編譯:由于空間限制,無法運行 JIT 編譯版本;它可能需要比急切模式稍多的空間。訓練速度:使用 JAX 急切模式,約為 35 tokens / 秒。內存效率:穩定在約 70% 左右。擴展性:在 8 張 GPU 上,使用 JAX 的擴展性接近線性。由于硬件和顯存的限制,我們無法運行 JIT 編譯版本的 405B 模型,整個訓練過程是在 JAX 的急切模式下執行的,因此還有很大的進步空間。下圖中顯示了在一次微調訓練步驟中,8 張 GPU 的顯存利用率和 rocm-smi 輸出:GPU 利用率:
顯存利用率:
rocm-smi 輸出:
此前,Nikhil Sonti 分享過如何將 LLaMA 3.1 從 PyTorch 移植到 JAX。他指出,目前 90% 的大型語言模型(LLM)都運行在 NVIDIA GPU 上,但實際上還有一些同樣強大且性價比更高的替代方案。例如,在 Google TPU 上訓練和部署 Llama 3.1 的成本比 NVIDIA GPU 低約 30%。然而,支持非 NVIDIA 硬件的開發工具較為匱乏。Sonti 最初嘗試使用 PyTorch XLA 在 TPU 上訓練 Llama 3.1,但過程并不順利。XLA 與 PyTorch 的集成不夠完善,缺少一些關鍵的庫(如 bitsandbytes 無法正常運行),同時還遇到了一些難以解決的 HuggingFace 錯誤。為此,他決定調整策略,將 Llama 3.1 從 PyTorch 移植到 JAX,成功解決了這些問題。Sonti 還錄制了詳細的教程視頻,并開源了所有代碼:
在調用 jax.device_put 之后:
加入 LoRALoRA 通過將權重更新分解為低秩矩陣,減少了可訓練參數的數量,這對于微調大型模型特別有效。以下是在 AMD GPU 上微調 Llama 3.1-405 的 LoRA 的要點:將 LoRA 參數(lora_a 和 lora_b)與主模型參數分開。使用 jax.lax.stop_gradient (kernel) 來防止對主模型權重的更新。使用 lax.dot_general 進行快速、精確控制的矩陣運算。LoRA 輸出在添加到主輸出之前會被縮放為 (self.lora_alpha/self.lora_rank)。LoRADense 層在此設定一個自定義的 LoRADense 層,該層集成了 LoRA 參數:class LoRADense (nn.Module): features: int lora_rank: int = 8 lora_alpha: float = 16.0@nn.compactdef __call__(self, inputs: Any) -> Any:# Original kernel parameter (frozen) kernel = self.param ('kernel', ...) y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)# LoRA parameters (trainable) lora_a = self.variable ('lora_params', 'lora_a', ..., ...) lora_b = self.variable ('lora_params', 'lora_b', ..., ...)# Compute LoRA output lora_output = lax.dot_general (inputs, lora_a.value, ...) lora_output = lax.dot_general (lora_output, lora_b.value, ...)# Combine original output with LoRA modifications y += (self.lora_alpha/self.lora_rank) * lora_output return y.astype (self.dtype)分片 LoRA 參數為了高效地在設備之間分配 LoRA 參數,我們也通過 JAX 設定了分片規則,這確保了 LoRA 參數與主模型參數的分片一致,優化了內存使用和計算效率。LoRA A matrices (lora_a)LoRA A 矩陣(lora_a)分片規則:PS ("fsdp", "mp")可視化結果:如下圖所示,lora_a 參數被分片為 (8, 1),這意味著第一個軸在 8 個設備上進行分片("fsdp" 軸),而第二個軸未進行分片。
LoRA B 矩陣(lora_b)分片規則:PS ("mp", "fsdp")可視化結果:如下圖所示,lora_b 參數被分片為 (1, 8),這意味著第二個軸在 8 個設備上進行分片(fsdp 軸),而第一個軸未進行分片。
這種分片策略優化了參數的分配,減少了通信開銷,并在訓練過程中增強了并行性。它確保每個設備僅持有一部分 LoRA 參數,使得大模型如 LLaMA 405B 的高效擴展成為可能。僅更新 LoRA 參數為了優化訓練,在微調 LLaMA 405B 模型,只計算 LoRA 參數的梯度,保持主模型參數不變。這個方法減少了內存使用,并加速了訓練,因為只更新較少的參數。可以移步 GitHub 倉庫,查看實現細節。在訓練過程中,每一步都涉及將一批輸入數據通過模型進行處理。由于只有 LoRA 參數是可訓練的,因此模型的預測和計算的損失僅依賴于這些參數,然后對 LoRA 參數進行反向傳播。只更新這些參數簡化了訓練過程,使得在多個 GPU 上高效微調像 LLaMA 405B 這樣的大型模型成為可能。更多研究細節,請參考原博客。









相關推薦
- 免責聲明
- 本文所包含的觀點僅代表作者個人看法,不代表新火種的觀點。在新火種上獲取的所有信息均不應被視為投資建議。新火種對本文可能提及或鏈接的任何項目不表示認可。 交易和投資涉及高風險,讀者在采取與本文內容相關的任何行動之前,請務必進行充分的盡職調查。最終的決策應該基于您自己的獨立判斷。新火種不對因依賴本文觀點而產生的任何金錢損失負任何責任。